burn-tensor
Advanced tools
| <!--Follows the instructions at--> | ||
| <!--https://docs.rs/rustdoc-katex-demo/0.1.5/rustdoc_katex_demo/--> | ||
| <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.22/dist/katex.min.css" integrity="sha384-5TcZemv2l/9On385z///+d7MSYlvIEw9FuZTIdZ14vJLqWphw7e7ZPuOiCHJcFCP" crossorigin="anonymous"> | ||
| <script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.22/dist/katex.min.js" integrity="sha384-cMkvdD8LoxVzGF/RPUKAcvmm49FQ0oxwDF3BGKtDXcEc+T1b2N+teh/OJfpU0jr6" crossorigin="anonymous"></script> | ||
| <script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.22/dist/contrib/auto-render.min.js" integrity="sha384-hCXGrW6PitJEwbkoStFjeJxv+fSOOQKOPbJxSfM6G5sWZjAyWhXiTIIAmQqnlLlh" crossorigin="anonymous"></script> | ||
| <script> | ||
| document.addEventListener("DOMContentLoaded", function() { | ||
| renderMathInElement(document.body, { | ||
| // customised options | ||
| // • auto-render specific keys, e.g.: | ||
| delimiters: [ | ||
| {left: '$$', right: '$$', display: true}, | ||
| {left: '$', right: '$', display: false}, | ||
| {left: '\\(', right: '\\)', display: false}, | ||
| {left: '\\[', right: '\\]', display: true} | ||
| ], | ||
| // • rendering keys, e.g.: | ||
| throwOnError : false | ||
| }); | ||
| }); | ||
| </script> |
| use alloc::vec::Vec; | ||
| use core::ops::Range; | ||
| use crate::{Element, ElementConversion, Tensor, backend::Backend, ops::PadMode}; | ||
| use super::Numeric; | ||
| /// Helper to build a range array for slice_assign, selecting a portion of one dimension. | ||
/// Helper to build a range array for `slice_assign`, selecting a portion of one dimension.
///
/// Returns one `Range` per dimension: the full extent `0..size` for every
/// dimension except `target_dim`, which gets `start..start + len`.
fn build_slice_ranges<const D: usize>(
    dims: [usize; D],
    target_dim: usize,
    start: usize,
    len: usize,
) -> [Range<usize>; D] {
    // `core::array::from_fn` builds the fixed-size array directly, avoiding the
    // intermediate `Vec` allocation and the fallible `try_into().unwrap()` of
    // the collect-based approach.
    core::array::from_fn(|i| {
        if i == target_dim {
            start..start + len
        } else {
            0..dims[i]
        }
    })
}
| impl<B, const D: usize, K> Tensor<B, D, K> | ||
| where | ||
| B: Backend, | ||
| K: Numeric<B>, | ||
| K::Elem: Element, | ||
| { | ||
| /// Pads the tensor on the last two dimensions using the specified padding mode. | ||
| /// | ||
| /// **Note**: Currently, padding is only supported on the last two dimensions of a tensor | ||
| /// (typically height and width for image data in NCHW format). | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `padding` - A tuple `(left, right, top, bottom)` specifying padding for the last two dimensions. | ||
| /// * `mode` - The padding mode: `Constant(value)`, `Reflect`, or `Edge`. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A new tensor with the specified padding applied. | ||
| /// | ||
| /// # Panics | ||
| /// | ||
| /// - `Reflect` mode panics if padding exceeds `dimension_size - 1`. | ||
| /// - `Edge` mode panics if padding is applied to a zero-sized dimension. | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::{Tensor, Shape}; | ||
| /// use burn_tensor::ops::PadMode; | ||
| /// | ||
| /// fn example<B: Backend<FloatElem: From<f32>>>() { | ||
| /// let device = B::Device::default(); | ||
| /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device); | ||
| /// | ||
| /// // Constant padding with value 0.0 | ||
| /// let padded = tensor.clone().pad((1, 1, 1, 1), PadMode::Constant(0.0)); | ||
| /// // [ | ||
| /// // [0.0, 0.0, 0.0, 0.0, 0.0], | ||
| /// // [0.0, 12.0, -2.0, 3.0, 0.0], | ||
| /// // [0.0, 5.0, 3.0, 6.0, 0.0], | ||
| /// // [0.0, 0.0, 0.0, 0.0, 0.0] | ||
| /// // ] | ||
| /// | ||
| /// // Reflect padding | ||
| /// let padded = tensor.clone().pad((1, 1, 0, 0), PadMode::Reflect); | ||
| /// // [[−2.0, 12.0, −2.0, 3.0, −2.0], [3.0, 5.0, 3.0, 6.0, 3.0]] | ||
| /// | ||
| /// // Edge padding | ||
| /// let padded = tensor.pad((1, 1, 0, 0), PadMode::Edge); | ||
| /// // [[12.0, 12.0, −2.0, 3.0, 3.0], [5.0, 5.0, 3.0, 6.0, 6.0]] | ||
| /// } | ||
| /// ``` | ||
| pub fn pad(self, padding: (usize, usize, usize, usize), mode: PadMode) -> Self { | ||
| match mode { | ||
| PadMode::Constant(value) => pad_constant(self, padding, value), | ||
| PadMode::Reflect => pad_reflect(self, padding), | ||
| PadMode::Edge => pad_edge(self, padding), | ||
| } | ||
| } | ||
| } | ||
| /// Pad with a constant value. | ||
| pub fn pad_constant<B, const D: usize, K, E>( | ||
| tensor: Tensor<B, D, K>, | ||
| padding: (usize, usize, usize, usize), | ||
| value: E, | ||
| ) -> Tensor<B, D, K> | ||
| where | ||
| B: Backend, | ||
| K: Numeric<B>, | ||
| K::Elem: Element, | ||
| E: ElementConversion, | ||
| { | ||
| let (left, right, top, bottom) = padding; | ||
| let mut padded_dims: [usize; D] = tensor.dims(); | ||
| // Update the last two dimensions with padding | ||
| padded_dims[D - 2] += top + bottom; | ||
| padded_dims[D - 1] += left + right; | ||
| // Create the ranges for the padded tensor | ||
| let ranges: [core::ops::Range<usize>; D] = padded_dims | ||
| .iter() | ||
| .enumerate() | ||
| .map(|(i, &dim)| { | ||
| if i == D - 2 { | ||
| top..dim - bottom | ||
| } else if i == D - 1 { | ||
| left..dim - right | ||
| } else { | ||
| 0..dim | ||
| } | ||
| }) | ||
| .collect::<Vec<core::ops::Range<usize>>>() | ||
| .try_into() | ||
| .unwrap(); | ||
| // Create the padded tensor | ||
| let padded_tensor = Tensor::full(padded_dims, value, &tensor.device()); | ||
| // Assign the original tensor data to the appropriate slice of the padded tensor | ||
| padded_tensor.slice_assign(ranges, tensor) | ||
| } | ||
| /// Pad using reflection at the boundaries (excluding edge values). | ||
| /// | ||
| /// For ONNX "reflect" mode: mirrors from index 1, not index 0. | ||
| /// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[3, 2, 1, 2, 3, 4]` | ||
| pub fn pad_reflect<B, const D: usize, K>( | ||
| tensor: Tensor<B, D, K>, | ||
| padding: (usize, usize, usize, usize), | ||
| ) -> Tensor<B, D, K> | ||
| where | ||
| B: Backend, | ||
| K: Numeric<B>, | ||
| K::Elem: Element, | ||
| { | ||
| let (left, right, top, bottom) = padding; | ||
| let dims = tensor.dims(); | ||
| // Validate padding doesn't exceed tensor dimensions | ||
| // For reflect mode, padding must be less than the corresponding dimension | ||
| assert!( | ||
| top < dims[D - 2] && bottom < dims[D - 2], | ||
| "Reflect padding on height ({}, {}) must be less than height dimension ({})", | ||
| top, | ||
| bottom, | ||
| dims[D - 2] | ||
| ); | ||
| assert!( | ||
| left < dims[D - 1] && right < dims[D - 1], | ||
| "Reflect padding on width ({}, {}) must be less than width dimension ({})", | ||
| left, | ||
| right, | ||
| dims[D - 1] | ||
| ); | ||
| let mut result = tensor; | ||
| // Pad height dimension (D - 2): top and bottom | ||
| if top > 0 || bottom > 0 { | ||
| result = pad_reflect_dim(result, D - 2, top, bottom); | ||
| } | ||
| // Pad width dimension (D - 1): left and right | ||
| if left > 0 || right > 0 { | ||
| result = pad_reflect_dim(result, D - 1, left, right); | ||
| } | ||
| result | ||
| } | ||
| /// Helper to pad a single dimension using reflection. | ||
| fn pad_reflect_dim<B, const D: usize, K>( | ||
| tensor: Tensor<B, D, K>, | ||
| dim: usize, | ||
| pad_before: usize, | ||
| pad_after: usize, | ||
| ) -> Tensor<B, D, K> | ||
| where | ||
| B: Backend, | ||
| K: Numeric<B>, | ||
| K::Elem: Element, | ||
| { | ||
| let dims = tensor.dims(); | ||
| let dim_size = dims[dim]; | ||
| // Calculate output dimensions | ||
| let mut output_dims = dims; | ||
| output_dims[dim] += pad_before + pad_after; | ||
| // Create output tensor and place original in the center | ||
| let output = Tensor::zeros(output_dims, &tensor.device()); | ||
| let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size); | ||
| let mut output = output.slice_assign(original_range, tensor.clone()); | ||
| // Assign reflected "before" padding (e.g., top or left) | ||
| // Reflect excludes the edge, so we take indices [1..pad_before+1] and flip | ||
| if pad_before > 0 { | ||
| let before_slice = tensor.clone().narrow(dim, 1, pad_before); | ||
| let before_flipped = before_slice.flip([dim as isize]); | ||
| let before_range = build_slice_ranges(output_dims, dim, 0, pad_before); | ||
| output = output.slice_assign(before_range, before_flipped); | ||
| } | ||
| // Assign reflected "after" padding (e.g., bottom or right) | ||
| // Take indices [dim_size - pad_after - 1..dim_size - 1] and flip | ||
| if pad_after > 0 { | ||
| let start = dim_size - pad_after - 1; | ||
| let after_slice = tensor.narrow(dim, start, pad_after); | ||
| let after_flipped = after_slice.flip([dim as isize]); | ||
| let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after); | ||
| output = output.slice_assign(after_range, after_flipped); | ||
| } | ||
| output | ||
| } | ||
| /// Pad by replicating edge values. | ||
| /// | ||
| /// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[1, 1, 1, 2, 3, 4]` | ||
| pub fn pad_edge<B, const D: usize, K>( | ||
| tensor: Tensor<B, D, K>, | ||
| padding: (usize, usize, usize, usize), | ||
| ) -> Tensor<B, D, K> | ||
| where | ||
| B: Backend, | ||
| K: Numeric<B>, | ||
| K::Elem: Element, | ||
| { | ||
| let (left, right, top, bottom) = padding; | ||
| let dims = tensor.dims(); | ||
| // Validate dimensions are non-zero when padding is requested | ||
| if top > 0 || bottom > 0 { | ||
| assert!( | ||
| dims[D - 2] > 0, | ||
| "Cannot apply edge padding to zero-sized height dimension" | ||
| ); | ||
| } | ||
| if left > 0 || right > 0 { | ||
| assert!( | ||
| dims[D - 1] > 0, | ||
| "Cannot apply edge padding to zero-sized width dimension" | ||
| ); | ||
| } | ||
| let mut result = tensor; | ||
| // Pad height dimension (D - 2): top and bottom | ||
| if top > 0 || bottom > 0 { | ||
| result = pad_edge_dim(result, D - 2, top, bottom); | ||
| } | ||
| // Pad width dimension (D - 1): left and right | ||
| if left > 0 || right > 0 { | ||
| result = pad_edge_dim(result, D - 1, left, right); | ||
| } | ||
| result | ||
| } | ||
| /// Helper to pad a single dimension by replicating edge values. | ||
| fn pad_edge_dim<B, const D: usize, K>( | ||
| tensor: Tensor<B, D, K>, | ||
| dim: usize, | ||
| pad_before: usize, | ||
| pad_after: usize, | ||
| ) -> Tensor<B, D, K> | ||
| where | ||
| B: Backend, | ||
| K: Numeric<B>, | ||
| K::Elem: Element, | ||
| { | ||
| let dims = tensor.dims(); | ||
| let dim_size = dims[dim]; | ||
| // Calculate output dimensions | ||
| let mut output_dims = dims; | ||
| output_dims[dim] += pad_before + pad_after; | ||
| // Create output tensor and place original in the center | ||
| let output = Tensor::zeros(output_dims, &tensor.device()); | ||
| let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size); | ||
| let mut output = output.slice_assign(original_range, tensor.clone()); | ||
| // Assign "before" padding by repeating the first element | ||
| if pad_before > 0 { | ||
| let first_slice = tensor.clone().narrow(dim, 0, 1); | ||
| let before_pad = first_slice.repeat_dim(dim, pad_before); | ||
| let before_range = build_slice_ranges(output_dims, dim, 0, pad_before); | ||
| output = output.slice_assign(before_range, before_pad); | ||
| } | ||
| // Assign "after" padding by repeating the last element | ||
| if pad_after > 0 { | ||
| let last_slice = tensor.narrow(dim, dim_size - 1, 1); | ||
| let after_pad = last_slice.repeat_dim(dim, pad_after); | ||
| let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after); | ||
| output = output.slice_assign(after_range, after_pad); | ||
| } | ||
| output | ||
| } |
| use crate::{Tensor, TensorPrimitive, backend::Backend}; | ||
| // We re-export those types. | ||
| pub use burn_backend::{ | ||
| QTensorPrimitive, | ||
| tensor::{Calibration, QParamTensor, QParams, QuantizationParametersPrimitive, params_shape}, | ||
| }; | ||
| pub use burn_std::quantization::{ | ||
| BlockSize, QuantLevel, QuantMode, QuantParam, QuantPropagation, QuantScheme, QuantStore, | ||
| QuantValue, QuantizedBytes, | ||
| }; | ||
/// The tensor quantization parameters.
pub type QuantizationParameters<B> = QParams<Tensor<B, 1>>;
/// The observed input calibration range.
///
/// Produced by [`compute_range`] and consumed by [`compute_q_params`] to
/// derive the quantization parameters.
#[derive(Clone, Debug)]
pub struct CalibrationRange<B: Backend> {
    /// Minimum observed value(s).
    pub min: Tensor<B, 1>,
    /// Maximum observed value(s).
    pub max: Tensor<B, 1>,
}
| /// Compute the quantization range mapping. | ||
| pub fn compute_range<B: Backend, const D: usize>( | ||
| scheme: &QuantScheme, | ||
| tensor: &Tensor<B, D>, | ||
| calibration: &Calibration, | ||
| ) -> CalibrationRange<B> { | ||
| let (min, max) = match &tensor.primitive { | ||
| TensorPrimitive::Float(tensor) => { | ||
| burn_backend::tensor::compute_range::<B>(scheme, tensor.clone(), calibration) | ||
| } | ||
| TensorPrimitive::QFloat(_) => unreachable!(), | ||
| }; | ||
| CalibrationRange { | ||
| min: Tensor::from_primitive(TensorPrimitive::Float(min)), | ||
| max: Tensor::from_primitive(TensorPrimitive::Float(max)), | ||
| } | ||
| } | ||
| /// Compute the quantization parameters. | ||
| pub fn compute_q_params<B: Backend>( | ||
| scheme: &QuantScheme, | ||
| range: CalibrationRange<B>, | ||
| ) -> QuantizationParameters<B> { | ||
| match (range.min.primitive, range.max.primitive) { | ||
| (TensorPrimitive::Float(min), TensorPrimitive::Float(max)) => { | ||
| let qparams = burn_backend::tensor::compute_q_params::<B>(scheme, min, max); | ||
| QuantizationParameters { | ||
| scales: Tensor::from_primitive(TensorPrimitive::Float(qparams.scales)), | ||
| } | ||
| } | ||
| _ => unreachable!(), | ||
| } | ||
| } |
| #[burn_tensor_testgen::testgen(module_attention)] | ||
| mod tests { | ||
| use super::*; | ||
| use burn_tensor::Distribution; | ||
| use burn_tensor::TensorData; | ||
| use burn_tensor::module::attention; | ||
| use burn_tensor::module::naive_attention; | ||
| use burn_tensor::{Tolerance, ops::FloatElem}; | ||
| type FT = FloatElem<TestBackend>; | ||
| #[test] | ||
| fn test_attention_no_mask() { | ||
| let num_batches = 1; | ||
| let num_heads = 1; | ||
| let seq_q = 128; | ||
| let seq_kv = 128; | ||
| let head_dim = 64; | ||
| let val_dim = 64; | ||
| let query = TestTensor::<4>::random( | ||
| [num_batches, num_heads, seq_q, head_dim], | ||
| Distribution::Uniform(0., 1.), | ||
| &Default::default(), | ||
| ); | ||
| let key = TestTensor::<4>::random( | ||
| [num_batches, num_heads, seq_kv, head_dim], | ||
| Distribution::Uniform(0., 1.), | ||
| &Default::default(), | ||
| ); | ||
| let value = TestTensor::<4>::random( | ||
| [num_batches, num_heads, seq_kv, val_dim], | ||
| Distribution::Uniform(0., 1.), | ||
| &Default::default(), | ||
| ); | ||
| let output = attention(query.clone(), key.clone(), value.clone(), None); | ||
| let expected = naive_attention::<TestBackend>(query, key, value, None); | ||
| output | ||
| .into_data() | ||
| .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::relative(1e-2)); | ||
| } | ||
| } |
| { | ||
| "git": { | ||
"sha1": "0368cc660dc9fc084292795ffe6b4e060d5aa668",
| "sha1": "42edc63ecfb5b02e606c8789b6c6fed867bd6cc5" | ||
| }, | ||
| "path_in_vcs": "crates/burn-tensor" | ||
| } |
+10
-23
@@ -15,3 +15,3 @@ # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO | ||
| name = "burn-tensor" | ||
| version = "0.20.0-pre.4" | ||
| version = "0.20.0-pre.5" | ||
| authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"] | ||
@@ -55,22 +55,9 @@ build = false | ||
| cubecl = [ | ||
| "dep:cubecl", | ||
| "burn-std/cubecl", | ||
| "burn-backend/cubecl", | ||
| ] | ||
| cubecl-cpu = [ | ||
| "cubecl", | ||
| "cubecl/cpu", | ||
| ] | ||
| cubecl-cuda = [ | ||
| "cubecl", | ||
| "cubecl/cuda", | ||
| ] | ||
| cubecl-hip = [ | ||
| "cubecl", | ||
| "cubecl/hip", | ||
| ] | ||
| cubecl-wgpu = [ | ||
| "cubecl", | ||
| "cubecl/wgpu", | ||
| ] | ||
| cubecl-cpu = ["burn-backend/cubecl-cpu"] | ||
| cubecl-cuda = ["burn-backend/cubecl-cuda"] | ||
| cubecl-hip = ["burn-backend/cubecl-hip"] | ||
| cubecl-wgpu = ["burn-backend/cubecl-wgpu"] | ||
| default = ["std"] | ||
@@ -81,3 +68,3 @@ doc = ["default"] | ||
| "burn-tensor-testgen", | ||
| "cubecl", | ||
| "dep:cubecl", | ||
| ] | ||
@@ -96,11 +83,11 @@ std = [ | ||
| [dependencies.burn-backend] | ||
| version = "0.20.0-pre.4" | ||
| version = "0.20.0-pre.5" | ||
| default-features = false | ||
| [dependencies.burn-std] | ||
| version = "0.20.0-pre.4" | ||
| version = "0.20.0-pre.5" | ||
| default-features = false | ||
| [dependencies.burn-tensor-testgen] | ||
| version = "0.20.0-pre.4" | ||
| version = "0.20.0-pre.5" | ||
| optional = true | ||
@@ -113,3 +100,3 @@ | ||
| [dependencies.cubecl] | ||
| version = "=0.9.0-pre.4" | ||
| version = "=0.9.0-pre.5" | ||
| optional = true | ||
@@ -116,0 +103,0 @@ default-features = false |
+0
-32
@@ -32,33 +32,1 @@ #![cfg_attr(not(feature = "std"), no_std)] | ||
| pub use burn_std::{Bytes, bf16, f16}; | ||
| #[cfg(feature = "cubecl-wgpu")] | ||
| mod cube_wgpu { | ||
| use crate::backend::DeviceOps; | ||
| use cubecl::wgpu::WgpuDevice; | ||
| impl DeviceOps for WgpuDevice {} | ||
| } | ||
| #[cfg(feature = "cubecl-cuda")] | ||
| mod cube_cuda { | ||
| use crate::backend::DeviceOps; | ||
| use cubecl::cuda::CudaDevice; | ||
| impl DeviceOps for CudaDevice {} | ||
| } | ||
| #[cfg(all(feature = "cubecl-cpu", target_os = "linux"))] | ||
| mod cube_cpu { | ||
| use crate::backend::DeviceOps; | ||
| use cubecl::cpu::CpuDevice; | ||
| impl DeviceOps for CpuDevice {} | ||
| } | ||
| #[cfg(feature = "cubecl-hip")] | ||
| mod cube_hip { | ||
| use crate::backend::DeviceOps; | ||
| use cubecl::hip::AmdDevice; | ||
| impl DeviceOps for AmdDevice {} | ||
| } |
@@ -197,2 +197,4 @@ use crate::backend::Backend; | ||
| /// | ||
| /// Also referred to as [`softmax1`](https://www.evanmiller.org/attention-is-off-by-one.html). | ||
| /// | ||
| /// This function is similar to the softmax function, but it allows for "no selection" when | ||
@@ -222,7 +224,7 @@ /// all the outputs are close to zero. | ||
| let tensor = tensor.clone() - tensor.detach().max_dim(dim); | ||
| let tensor = tensor.exp(); | ||
| let tensor_tmp = tensor.clone().sum_dim(dim); | ||
| let max_vals = tensor.clone().detach().max_dim(dim); | ||
| let exp_x = (tensor - max_vals.clone()).exp(); | ||
| let sum_exp = exp_x.clone().sum_dim(dim); | ||
| tensor.div(tensor_tmp + 1) | ||
| exp_x.div(sum_exp + max_vals.neg().exp()) | ||
| } | ||
@@ -381,8 +383,7 @@ | ||
| let new_len = tensor.dims()[dim] / 2; | ||
| // The `s!` macro is used for slicing tensors along a specific dimension. | ||
| // Usage: s![dim, start..end] slices the tensor along `dim` from `start` to `end` (exclusive). | ||
| let a = tensor.clone().slice(s![dim, 0..new_len]); | ||
| let b = tensor.slice(s![dim, new_len..new_len * 2]); | ||
| let a = tensor.clone().slice_dim(dim, s![0..new_len]); | ||
| let b = tensor.slice_dim(dim, s![new_len..new_len * 2]); | ||
| a.mul(sigmoid(b)) | ||
| } |
@@ -1,5 +0,5 @@ | ||
| use crate::{ | ||
| BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::AutodiffBackend, | ||
| }; | ||
| pub use burn_backend::tensor::BasicAutodiffOps; | ||
| use crate::{Tensor, TensorPrimitive, backend::AutodiffBackend}; | ||
| impl<const D: usize, B: AutodiffBackend> Tensor<B, D> { | ||
@@ -76,93 +76,1 @@ /// Backward pass of the tensor. | ||
| } | ||
| impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float { | ||
| type InnerKind = Float; | ||
| fn inner( | ||
| tensor: <Self as TensorKind<B>>::Primitive, | ||
| ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive { | ||
| match tensor { | ||
| TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)), | ||
| TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)), | ||
| } | ||
| } | ||
| fn from_inner( | ||
| inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive, | ||
| ) -> <Self as TensorKind<B>>::Primitive { | ||
| match inner { | ||
| TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)), | ||
| TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)), | ||
| } | ||
| } | ||
| } | ||
| impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int { | ||
| type InnerKind = Int; | ||
| fn inner( | ||
| tensor: <Self as TensorKind<B>>::Primitive, | ||
| ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive { | ||
| B::int_inner(tensor) | ||
| } | ||
| fn from_inner( | ||
| inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive, | ||
| ) -> <Self as TensorKind<B>>::Primitive { | ||
| B::int_from_inner(inner) | ||
| } | ||
| } | ||
| impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool { | ||
| type InnerKind = Bool; | ||
| fn inner( | ||
| tensor: <Self as TensorKind<B>>::Primitive, | ||
| ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive { | ||
| B::bool_inner(tensor) | ||
| } | ||
| fn from_inner( | ||
| inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive, | ||
| ) -> <Self as TensorKind<B>>::Primitive { | ||
| B::bool_from_inner(inner) | ||
| } | ||
| } | ||
| /// Trait that list all operations that can be applied on all tensors on an autodiff backend. | ||
| /// | ||
| /// # Warnings | ||
| /// | ||
| /// This is an internal trait, use the public API provided by [tensor struct](Tensor). | ||
| pub trait BasicAutodiffOps<B: AutodiffBackend>: BasicOps<B> + BasicOps<B::InnerBackend> { | ||
| /// Inner primitive tensor. | ||
| type InnerKind: BasicOps<B::InnerBackend>; | ||
| /// Returns the inner tensor without the autodiff information. | ||
| /// | ||
| /// # Remarks | ||
| /// | ||
| /// This is a low-level function used internally by the library to call different backend functions | ||
| /// with static dispatch. It is not designed for direct usage by users, and not recommended to import | ||
| /// or use this function directly. | ||
| /// | ||
| /// Users should prefer the [Tensor::inner](Tensor::inner) function, | ||
| /// which is more high-level and designed for public use. | ||
| fn inner( | ||
| tensor: <Self as TensorKind<B>>::Primitive, | ||
| ) -> <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive; | ||
| /// Convert a tensor to the autodiff backend. | ||
| /// | ||
| /// # Remarks | ||
| /// | ||
| /// This is a low-level function used internally by the library to call different backend functions | ||
| /// with static dispatch. It is not designed for direct usage by users, and not recommended to import | ||
| /// or use this function directly. | ||
| /// | ||
| /// Users should prefer the [Tensor::from_inner](Tensor::from_inner) function, | ||
| /// which is more high-level and designed for public use. | ||
| fn from_inner( | ||
| inner: <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive, | ||
| ) -> <Self as TensorKind<B>>::Primitive; | ||
| } |
@@ -905,15 +905,4 @@ use crate::ops::FloatElem; | ||
| if range.start >= range.end && slice.step > 0 { | ||
| check = check.register( | ||
| "Slice Assign", | ||
| TensorError::new( | ||
| "The provided slice has a range where the start index is bigger or \ | ||
| equal to its end with positive step.", | ||
| ) | ||
| .details(format!( | ||
| "The range start ({}) must be smaller than its end ({}) for positive step ({}) at dimension {}", | ||
| range.start, range.end, slice.step, i | ||
| )), | ||
| ); | ||
| } | ||
| // Note: Empty slices (start >= end with positive step) are handled at the API level | ||
| // by returning the original tensor unchanged, so we don't check for them here. | ||
| } | ||
@@ -920,0 +909,0 @@ |
@@ -0,1 +1,3 @@ | ||
| use burn_backend::tensor::QuantizationParametersPrimitive; | ||
| use crate::AsIndex; | ||
@@ -13,6 +15,4 @@ use crate::FloatDType; | ||
| use crate::tensor::{Distribution, TensorData}; | ||
| use crate::{Int, TensorPrimitive}; | ||
| use crate::{Bool, Int, TensorPrimitive}; | ||
| use super::Bool; | ||
| /// Default RTOL value for `is_close` and `all_close`. | ||
@@ -406,3 +406,5 @@ pub const DEFAULT_RTOL: f64 = 1e-5; | ||
| scheme, | ||
| qparams.into(), | ||
| QuantizationParametersPrimitive { | ||
| scales: qparams.scales.primitive.tensor(), | ||
| }, | ||
| ))) | ||
@@ -409,0 +411,0 @@ } |
| pub(crate) mod check; | ||
| mod argwhere; | ||
| mod autodiff; | ||
@@ -11,5 +10,4 @@ mod base; | ||
| mod int; | ||
| mod kind; | ||
| mod numeric; | ||
| mod sort; | ||
| mod pad; | ||
| mod take; | ||
@@ -19,3 +17,2 @@ mod transaction; | ||
| pub use argwhere::argwhere_data; | ||
| pub use autodiff::*; | ||
@@ -25,5 +22,5 @@ pub use base::*; | ||
| pub use float::{DEFAULT_ATOL, DEFAULT_RTOL}; | ||
| pub use kind::*; | ||
| pub use numeric::*; | ||
| pub use sort::{argsort, sort, sort_with_indices}; | ||
| pub use transaction::*; | ||
| pub use burn_backend::tensor::IndexingUpdateOp; |
@@ -1,6 +0,6 @@ | ||
| use super::{BasicOps, Tensor, TensorPrimitive}; | ||
| use super::{BasicOps, Tensor}; | ||
| use crate::{ | ||
| TensorData, | ||
| backend::{Backend, ExecutionError}, | ||
| ops::{BoolTensor, IntTensor, TransactionPrimitive}, | ||
| ops::TransactionPrimitive, | ||
| }; | ||
@@ -26,16 +26,8 @@ use alloc::vec::Vec; | ||
| op: TransactionPrimitive<B>, | ||
| orders: Vec<Order>, | ||
| } | ||
| enum Order { | ||
| Float(usize), | ||
| QFloat(usize), | ||
| Int(usize), | ||
| Bool(usize), | ||
| } | ||
| impl<B: Backend> Transaction<B> { | ||
| /// Add a [tensor](Tensor) to the transaction to be read. | ||
| pub fn register<const D: usize, K: BasicOps<B>>(mut self, tensor: Tensor<B, D, K>) -> Self { | ||
| K::register_transaction(&mut self, tensor.into_primitive()); | ||
| K::register_transaction(&mut self.op, tensor.into_primitive()); | ||
| self | ||
@@ -64,43 +56,4 @@ } | ||
| pub async fn execute_async(self) -> Result<Vec<TensorData>, ExecutionError> { | ||
| let result = B::tr_execute(self.op).await?; | ||
| let mut floats: Vec<_> = result.read_floats.into_iter().map(Some).collect(); | ||
| let mut qfloats: Vec<_> = result.read_qfloats.into_iter().map(Some).collect(); | ||
| let mut ints: Vec<_> = result.read_ints.into_iter().map(Some).collect(); | ||
| let mut bools: Vec<_> = result.read_bools.into_iter().map(Some).collect(); | ||
| Ok(self | ||
| .orders | ||
| .into_iter() | ||
| .map(|order| match order { | ||
| Order::Float(index) => floats.get_mut(index).unwrap().take().unwrap(), | ||
| Order::QFloat(index) => qfloats.get_mut(index).unwrap().take().unwrap(), | ||
| Order::Int(index) => ints.get_mut(index).unwrap().take().unwrap(), | ||
| Order::Bool(index) => bools.get_mut(index).unwrap().take().unwrap(), | ||
| }) | ||
| .collect::<Vec<_>>()) | ||
| self.op.execute_async().await | ||
| } | ||
| pub(crate) fn register_float(&mut self, tensor: TensorPrimitive<B>) { | ||
| match tensor { | ||
| TensorPrimitive::Float(tensor) => { | ||
| self.orders.push(Order::Float(self.op.read_floats.len())); | ||
| self.op.read_floats.push(tensor); | ||
| } | ||
| TensorPrimitive::QFloat(tensor) => { | ||
| self.orders.push(Order::QFloat(self.op.read_qfloats.len())); | ||
| self.op.read_qfloats.push(tensor); | ||
| } | ||
| } | ||
| } | ||
| pub(crate) fn register_int(&mut self, tensor: IntTensor<B>) { | ||
| self.orders.push(Order::Int(self.op.read_ints.len())); | ||
| self.op.read_ints.push(tensor); | ||
| } | ||
| pub(crate) fn register_bool(&mut self, tensor: BoolTensor<B>) { | ||
| self.orders.push(Order::Bool(self.op.read_bools.len())); | ||
| self.op.read_bools.push(tensor); | ||
| } | ||
| } |
+17
-4
@@ -8,3 +8,9 @@ pub(crate) mod stats; | ||
| // Re-exported types | ||
| pub use burn_backend::{DataError, TensorData, Tolerance, distribution::*, element::*}; | ||
| pub use burn_backend::{ | ||
| DataError, TensorData, TensorMetadata, TensorPrimitive, Tolerance, | ||
| distribution::*, | ||
| element::*, | ||
| ops::TransactionPrimitive, | ||
| tensor::{Bool, Float, Int, TensorKind}, | ||
| }; | ||
| pub use burn_std::{ | ||
@@ -19,3 +25,5 @@ DType, FloatDType, IntDType, s, | ||
| /// The backend module. | ||
| pub mod backend; | ||
| pub mod backend { | ||
| pub use burn_backend::backend::*; | ||
| } | ||
@@ -34,7 +42,12 @@ /// The container module. | ||
| /// The burn module. | ||
| /// The neural network module. | ||
| pub mod module; | ||
| /// Operations on tensors module. | ||
| pub mod ops; | ||
| pub mod ops { | ||
| pub use burn_backend::backend::ops::*; | ||
| pub use burn_backend::tensor::{ | ||
| BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor, | ||
| }; | ||
| } | ||
@@ -41,0 +54,0 @@ /// Tensor quantization module. |
+50
-1
| use crate::{ | ||
| Int, Tensor, TensorPrimitive, | ||
| Bool, Int, Tensor, TensorPrimitive, | ||
| backend::Backend, | ||
@@ -410,1 +410,50 @@ check, | ||
| } | ||
| /// Computes scaled dot-product attention: softmax(QKᵗ / √d) · V, | ||
| /// optionally applying a 4D mask to the attention scores. | ||
| /// | ||
| /// # Arguments | ||
| /// - `query`: Query tensor of shape `[batch_size, seq_len_q, num_heads, head_dim]` | ||
| /// - `key`: Key tensor of shape `[batch_size, seq_len_k, num_heads, head_dim]` | ||
| /// - `value`: Value tensor of shape `[batch_size, seq_len_k, num_heads, head_dim]` | ||
| /// - `mask`: Optional boolean mask of shape `[batch_size, seq_len_q, num_heads, seq_len_k]`, | ||
| /// where `true` indicates positions to mask (i.e. set to -∞ before softmax). | ||
| /// | ||
| /// # Returns | ||
| /// A tensor of shape `[batch_size, seq_len_q, num_heads, head_dim]` | ||
| /// representing the attended context per head. | ||
| /// | ||
| /// # Note | ||
| /// This implementation does not support dropout and is intended for inference or | ||
| /// use cases where dropout is not needed. | ||
| pub fn attention<B: Backend>( | ||
| query: Tensor<B, 4>, | ||
| key: Tensor<B, 4>, | ||
| value: Tensor<B, 4>, | ||
| mask: Option<Tensor<B, 4, Bool>>, | ||
| ) -> Tensor<B, 3> { | ||
| Tensor::new(TensorPrimitive::Float(B::attention( | ||
| query.primitive.tensor(), | ||
| key.primitive.tensor(), | ||
| value.primitive.tensor(), | ||
| mask.map(|mask| mask.primitive), | ||
| ))) | ||
| } | ||
| #[cfg(feature = "export_tests")] | ||
| /// Exports naive attention to test backend's attention against | ||
| pub fn naive_attention<B: Backend>( | ||
| query: Tensor<B, 4>, | ||
| key: Tensor<B, 4>, | ||
| value: Tensor<B, 4>, | ||
| mask: Option<Tensor<B, 4, Bool>>, | ||
| ) -> Tensor<B, 3> { | ||
| Tensor::new(TensorPrimitive::Float( | ||
| crate::ops::attention::naive_attention::<B>( | ||
| query.primitive.tensor(), | ||
| key.primitive.tensor(), | ||
| value.primitive.tensor(), | ||
| mask.map(|mask| mask.primitive), | ||
| ), | ||
| )) | ||
| } |
@@ -22,3 +22,3 @@ #[burn_tensor_testgen::testgen(glu)] | ||
| output.into_data().assert_eq( | ||
| output.into_data().assert_approx_eq::<FloatType>( | ||
| &TensorData::from([[ | ||
@@ -29,5 +29,5 @@ [-0.2665, -0.2487, 0.6656, -0.2904], | ||
| ]]), | ||
| false, | ||
| Default::default(), | ||
| ); | ||
| } | ||
| } |
@@ -8,2 +8,3 @@ pub(crate) mod gelu; | ||
| pub(crate) mod prelu; | ||
| pub(crate) mod quiet_softmax; | ||
| pub(crate) mod relu; | ||
@@ -10,0 +11,0 @@ pub(crate) mod sigmoid; |
| #[burn_tensor_testgen::testgen(quiet_softmax)] | ||
| mod tests { | ||
| use super::*; | ||
| use burn_tensor::{activation, Tensor, TensorData}; | ||
| use burn_tensor::{Tolerance, ops::FloatElem}; | ||
| type FT = FloatElem<TestBackend>; | ||
| use burn_tensor::{Tensor, TensorData, activation}; | ||
| use burn_tensor::{Tolerance, ops::FloatElem}; | ||
| type FT = FloatElem<TestBackend>; | ||
| #[test] | ||
| fn test_quiet_softmax_d2() { | ||
| let tensor = TestTensor::from([[1.0, 7.0], [13.0, -3.0]]); | ||
| let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]); | ||
@@ -15,4 +15,6 @@ let output = activation::quiet_softmax(tensor, 1); | ||
| output.into_data().assert_approx_eq::<FT>(&expected, Tolerance::default()); | ||
| output | ||
| .into_data() | ||
| .assert_approx_eq::<FT>(&expected, Tolerance::default()); | ||
| } | ||
| } |
@@ -13,3 +13,3 @@ #[burn_tensor_testgen::testgen(clone_invariance)] | ||
| }; | ||
| use burn_tensor::{Distribution, Tensor, TensorData}; | ||
| use burn_tensor::{Distribution, IndexingUpdateOp, Tensor, TensorData}; | ||
| use burn_tensor::{Tolerance, ops::FloatElem}; | ||
@@ -511,3 +511,3 @@ type FT = FloatElem<TestBackend>; | ||
| let indices = TestTensorInt::ones(shape, &Default::default()); | ||
| tensor.scatter(0, indices, values) | ||
| tensor.scatter(0, indices, values, IndexingUpdateOp::Add) | ||
| } | ||
@@ -533,3 +533,3 @@ ); | ||
| let values = values.select(0, indices.clone()); | ||
| tensor.select_assign(0, indices, values) | ||
| tensor.select_assign(0, indices, values, IndexingUpdateOp::Add) | ||
| } | ||
@@ -755,3 +755,3 @@ ); | ||
| let indices = TestTensorInt::ones(shape, &Default::default()); | ||
| tensor.scatter(0, indices, values) | ||
| tensor.scatter(0, indices, values, IndexingUpdateOp::Add) | ||
| } | ||
@@ -777,3 +777,3 @@ ); | ||
| let values = values.select(0, indices.clone()); | ||
| tensor.select_assign(0, indices, values) | ||
| tensor.select_assign(0, indices, values, IndexingUpdateOp::Add) | ||
| } | ||
@@ -780,0 +780,0 @@ ); |
+3
-0
@@ -151,2 +151,3 @@ mod activation; | ||
| burn_tensor::testgen_gelu!(); | ||
| burn_tensor::testgen_glu!(); | ||
| burn_tensor::testgen_mish!(); | ||
@@ -162,2 +163,3 @@ burn_tensor::testgen_relu!(); | ||
| burn_tensor::testgen_tanh_activation!(); | ||
| burn_tensor::testgen_quiet_softmax!(); | ||
@@ -197,2 +199,3 @@ // test grid | ||
| burn_tensor::testgen_module_linear!(); | ||
| burn_tensor::testgen_module_attention!(); | ||
@@ -199,0 +202,0 @@ // test ops |
| mod adaptive_avgpool1d; | ||
| mod adaptive_avgpool2d; | ||
| mod attention; | ||
| mod avgpool1d; | ||
@@ -4,0 +5,0 @@ mod avgpool2d; |
+85
-0
@@ -120,2 +120,87 @@ #[burn_tensor_testgen::testgen(cat)] | ||
| } | ||
| #[test] | ||
| fn should_support_cat_with_empty_tensor() { | ||
| let device = Default::default(); | ||
| let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device); | ||
| let tensor_2: TestTensor<2> = TestTensor::empty([1, 0], &device); // Empty tensor with size 0 on dim 1 | ||
| // Concatenating with an empty tensor should just return the non-empty tensor | ||
| let output = TestTensor::cat(vec![tensor_1.clone(), tensor_2], 1); | ||
| let expected = TensorData::from([[1.0, 2.0, 3.0]]); | ||
| output | ||
| .into_data() | ||
| .assert_approx_eq::<FT>(&expected, Tolerance::default()); | ||
| } | ||
| #[test] | ||
| fn should_support_cat_with_empty_tensor_first() { | ||
| let device = Default::default(); | ||
| let tensor_1: TestTensor<2> = TestTensor::empty([1, 0], &device); // Empty tensor | ||
| let tensor_2 = TestTensor::<2>::from_data([[4.0, 5.0, 6.0]], &device); | ||
| // Empty tensor first, then non-empty | ||
| let output = TestTensor::cat(vec![tensor_1, tensor_2.clone()], 1); | ||
| let expected = TensorData::from([[4.0, 5.0, 6.0]]); | ||
| output | ||
| .into_data() | ||
| .assert_approx_eq::<FT>(&expected, Tolerance::default()); | ||
| } | ||
| #[test] | ||
| fn should_support_cat_with_multiple_empty_tensors() { | ||
| let device = Default::default(); | ||
| let tensor_1: TestTensor<2> = TestTensor::empty([2, 0], &device); | ||
| let tensor_2 = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); | ||
| let tensor_3: TestTensor<2> = TestTensor::empty([2, 0], &device); | ||
| let tensor_4 = TestTensor::<2>::from_data([[5.0], [6.0]], &device); | ||
| // Mix of empty and non-empty tensors | ||
| let output = TestTensor::cat(vec![tensor_1, tensor_2, tensor_3, tensor_4], 1); | ||
| let expected = TensorData::from([[1.0, 2.0, 5.0], [3.0, 4.0, 6.0]]); | ||
| output | ||
| .into_data() | ||
| .assert_approx_eq::<FT>(&expected, Tolerance::default()); | ||
| } | ||
| #[test] | ||
| fn should_support_cat_all_empty_tensors() { | ||
| let device = Default::default(); | ||
| let tensor_1: TestTensor<2> = TestTensor::empty([2, 0], &device); | ||
| let tensor_2: TestTensor<2> = TestTensor::empty([2, 0], &device); | ||
| // All empty tensors should produce an empty tensor | ||
| let output = TestTensor::cat(vec![tensor_1, tensor_2], 1); | ||
| assert_eq!(output.shape().dims, [2, 0]); | ||
| } | ||
| #[test] | ||
| fn should_support_cat_with_empty_tensor_int() { | ||
| let device = Default::default(); | ||
| let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device); | ||
| let tensor_2: TestTensorInt<2> = TestTensorInt::empty([1, 0], &device); | ||
| let output = Tensor::cat(vec![tensor_1, tensor_2], 1); | ||
| output | ||
| .into_data() | ||
| .assert_eq(&TensorData::from([[1, 2, 3]]), false); | ||
| } | ||
| #[test] | ||
| fn should_support_cat_with_empty_tensor_bool() { | ||
| let device = Default::default(); | ||
| let tensor_1 = TestTensorBool::<2>::from_data([[true, false, true]], &device); | ||
| let tensor_2: TestTensorBool<2> = TestTensorBool::empty([1, 0], &device); | ||
| let output = Tensor::cat(vec![tensor_1, tensor_2], 1); | ||
| output | ||
| .into_data() | ||
| .assert_eq(&TensorData::from([[true, false, true]]), false); | ||
| } | ||
| } |
| #[burn_tensor_testgen::testgen(gather_scatter)] | ||
| mod tests { | ||
| use super::*; | ||
| use burn_tensor::{Tensor, TensorData}; | ||
| use burn_tensor::{IndexingUpdateOp, Tensor, TensorData}; | ||
@@ -135,3 +135,3 @@ #[test] | ||
| #[test] | ||
| fn should_scatter_1d() { | ||
| fn should_scatter_add_1d() { | ||
| let device = Default::default(); | ||
@@ -142,3 +142,3 @@ let tensor = TestTensor::<1>::from_floats([0.0, 0.0, 0.0], &device); | ||
| let output = tensor.scatter(0, indices, values); | ||
| let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add); | ||
@@ -151,3 +151,3 @@ output | ||
| #[test] | ||
| fn should_scatter_1d_int() { | ||
| fn should_scatter_add_1d_int() { | ||
| let device = Default::default(); | ||
@@ -158,3 +158,3 @@ let tensor = TestTensorInt::<1>::from_ints([0, 0, 0], &device); | ||
| let output = tensor.scatter(0, indices, values); | ||
| let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add); | ||
@@ -167,3 +167,3 @@ output | ||
| #[test] | ||
| fn should_scatter_2d_dim0() { | ||
| fn should_scatter_add_2d_dim0() { | ||
| let device = Default::default(); | ||
@@ -174,3 +174,3 @@ let tensor = TestTensor::<2>::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device); | ||
| let output = tensor.scatter(0, indices, values); | ||
| let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add); | ||
@@ -183,3 +183,3 @@ output | ||
| #[test] | ||
| fn should_scatter_2d_dim1() { | ||
| fn should_scatter_add_2d_dim1() { | ||
| let device = Default::default(); | ||
@@ -190,3 +190,3 @@ let tensor = TestTensor::<2>::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device); | ||
| let output = tensor.scatter(1, indices, values); | ||
| let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add); | ||
@@ -199,3 +199,3 @@ output | ||
| #[test] | ||
| fn should_scatter_3d_dim1() { | ||
| fn should_scatter_add_3d_dim1() { | ||
| let device = Default::default(); | ||
@@ -219,3 +219,3 @@ let tensor = TestTensor::<3>::from_floats( | ||
| let output = tensor.scatter(1, indices, values); | ||
| let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add); | ||
| let expected = TensorData::from([ | ||
@@ -230,3 +230,3 @@ [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]], | ||
| #[test] | ||
| fn should_scatter_2d_dim1_diff_shape() { | ||
| fn should_scatter_add_2d_dim1_diff_shape() { | ||
| let device = Default::default(); | ||
@@ -237,3 +237,3 @@ let tensor = TestTensor::<2>::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device); | ||
| let output = tensor.scatter(1, indices, values); | ||
| let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add); | ||
@@ -253,4 +253,4 @@ output | ||
| tensor.scatter(0, indices, values); | ||
| tensor.scatter(0, indices, values, IndexingUpdateOp::Add); | ||
| } | ||
| } |
+249
-34
@@ -7,2 +7,3 @@ #[burn_tensor_testgen::testgen(padding)] | ||
| backend::Backend, | ||
| ops::PadMode, | ||
| tests::{Float as _, Int as _}, | ||
@@ -12,7 +13,7 @@ }; | ||
| #[test] | ||
| fn padding_2d_test() { | ||
| fn padding_constant_2d_test() { | ||
| let unpadded_floats: [[f32; 3]; 2] = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]; | ||
| let tensor = TestTensor::<2>::from(unpadded_floats); | ||
| let padded_tensor = tensor.pad((2, 2, 2, 2), FloatType::new(1.1)); | ||
| let padded_tensor = tensor.pad((2, 2, 2, 2), PadMode::Constant(1.1)); | ||
@@ -31,7 +32,7 @@ let expected = TensorData::from(as_type!(FloatType: [ | ||
| #[test] | ||
| fn padding_4d_test() { | ||
| fn padding_constant_4d_test() { | ||
| let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]]; | ||
| let tensor = TestTensor::<4>::from(unpadded_floats); | ||
| let padded_tensor = tensor.pad((2, 2, 2, 2), FloatType::new(1.1)); | ||
| let padded_tensor = tensor.pad((2, 2, 2, 2), PadMode::Constant(1.1)); | ||
@@ -51,7 +52,7 @@ let expected = TensorData::from(as_type!(FloatType: [[[ | ||
| #[test] | ||
| fn padding_asymmetric_test() { | ||
| fn padding_constant_asymmetric_test() { | ||
| let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]]; | ||
| let tensor = TestTensor::<4>::from(unpadded_floats); | ||
| let padded_tensor = tensor.pad((2, 1, 4, 3), FloatType::new(1.1)); | ||
| let padded_tensor = tensor.pad((2, 1, 4, 3), PadMode::Constant(1.1)); | ||
@@ -74,34 +75,248 @@ let expected = TensorData::from(as_type!(FloatType: [[[ | ||
| #[test] | ||
| fn padding_asymmetric_integer_test() { | ||
| let unpadded_ints = [[[[0, 1], [2, 3], [4, 5]]]]; | ||
| fn padding_reflect_2d_test() { | ||
| // Test reflect padding on a 2D tensor | ||
| // Input: [[1, 2, 3], [4, 5, 6]] | ||
| // With padding (1, 1, 1, 1): | ||
| // - Top: reflect row 1 -> [4, 5, 6] | ||
| // - Bottom: reflect row 0 -> [1, 2, 3] | ||
| // - Left: reflect col 1 | ||
| // - Right: reflect col 1 | ||
| let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); | ||
| let tensor = TestTensorInt::<4>::from(unpadded_ints); | ||
| let padded_tensor = tensor.pad((2, 1, 4, 3), IntType::new(6)); | ||
| let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Reflect); | ||
| let padded_primitive_data_expected = [[[ | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 0, 1, 6], | ||
| [6, 6, 2, 3, 6], | ||
| [6, 6, 4, 5, 6], | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 6, 6, 6], | ||
| ]]]; | ||
| let expected = TensorData::from(as_type!(IntType: [[[ | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 0, 1, 6], | ||
| [6, 6, 2, 3, 6], | ||
| [6, 6, 4, 5, 6], | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 6, 6, 6], | ||
| [6, 6, 6, 6, 6], | ||
| ]]])); | ||
| // Expected: reflect excludes the edge value | ||
| // Before padding height: [[1,2,3], [4,5,6]] | ||
| // After top pad (reflect row at index 1): [[4,5,6], [1,2,3], [4,5,6]] | ||
| // After bottom pad (reflect row at index 1 from end): [[4,5,6], [1,2,3], [4,5,6], [1,2,3]] | ||
| // Then pad width similarly | ||
| let expected = TensorData::from(as_type!(FloatType: [ | ||
| [5.0, 4.0, 5.0, 6.0, 5.0], | ||
| [2.0, 1.0, 2.0, 3.0, 2.0], | ||
| [5.0, 4.0, 5.0, 6.0, 5.0], | ||
| [2.0, 1.0, 2.0, 3.0, 2.0], | ||
| ])); | ||
| padded_tensor.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn padding_reflect_width_only_test() { | ||
| // Test reflect padding on width dimension only | ||
| let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0, 4.0]]); | ||
| let padded_tensor = tensor.pad((2, 2, 0, 0), PadMode::Reflect); | ||
| // Input: [1, 2, 3, 4] | ||
| // Reflect left 2: take indices [1, 2] = [2, 3], flip = [3, 2] | ||
| // Reflect right 2: take indices [1, 2] from end = [2, 3], flip = [3, 2] | ||
| // Result: [3, 2, 1, 2, 3, 4, 3, 2] | ||
| let expected = | ||
| TensorData::from(as_type!(FloatType: [[3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 2.0]])); | ||
| padded_tensor.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn padding_reflect_4d_test() { | ||
| // Test reflect padding on 4D tensor (common for images: NCHW) | ||
| let tensor = TestTensor::<4>::from([[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]); | ||
| let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Reflect); | ||
| let expected = TensorData::from(as_type!(FloatType: [[[[ | ||
| 5.0, 4.0, 5.0, 6.0, 5.0], | ||
| [2.0, 1.0, 2.0, 3.0, 2.0], | ||
| [5.0, 4.0, 5.0, 6.0, 5.0], | ||
| [8.0, 7.0, 8.0, 9.0, 8.0], | ||
| [5.0, 4.0, 5.0, 6.0, 5.0 | ||
| ]]]])); | ||
| padded_tensor.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn padding_edge_2d_test() { | ||
| // Test edge padding on a 2D tensor | ||
| let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); | ||
| let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Edge); | ||
| // Edge padding replicates the boundary values | ||
| let expected = TensorData::from(as_type!(FloatType: [ | ||
| [1.0, 1.0, 2.0, 3.0, 3.0], | ||
| [1.0, 1.0, 2.0, 3.0, 3.0], | ||
| [4.0, 4.0, 5.0, 6.0, 6.0], | ||
| [4.0, 4.0, 5.0, 6.0, 6.0], | ||
| ])); | ||
| padded_tensor.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn padding_edge_width_only_test() { | ||
| // Test edge padding on width dimension only | ||
| let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0, 4.0]]); | ||
| let padded_tensor = tensor.pad((2, 3, 0, 0), PadMode::Edge); | ||
| // Input: [1, 2, 3, 4] | ||
| // Left 2: [1, 1] | ||
| // Right 3: [4, 4, 4] | ||
| // Result: [1, 1, 1, 2, 3, 4, 4, 4, 4] | ||
| let expected = | ||
| TensorData::from(as_type!(FloatType: [[1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 4.0, 4.0, 4.0]])); | ||
| padded_tensor.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn padding_edge_4d_test() { | ||
| // Test edge padding on 4D tensor | ||
| let tensor = TestTensor::<4>::from([[[[1.0, 2.0], [3.0, 4.0]]]]); | ||
| let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Edge); | ||
| let expected = TensorData::from(as_type!(FloatType: [[[[ | ||
| 1.0, 1.0, 2.0, 2.0], | ||
| [1.0, 1.0, 2.0, 2.0], | ||
| [3.0, 3.0, 4.0, 4.0], | ||
| [3.0, 3.0, 4.0, 4.0 | ||
| ]]]])); | ||
| padded_tensor.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn padding_constant_default_test() { | ||
| // Test default PadMode (Constant with 0.0) | ||
| let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); | ||
| let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::default()); | ||
| let expected = TensorData::from(as_type!(FloatType: [ | ||
| [0.0, 0.0, 0.0, 0.0], | ||
| [0.0, 1.0, 2.0, 0.0], | ||
| [0.0, 3.0, 4.0, 0.0], | ||
| [0.0, 0.0, 0.0, 0.0], | ||
| ])); | ||
| padded_tensor.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn padding_reflect_max_valid_test() { | ||
| // Test reflect padding at maximum valid size (dim_size - 1) | ||
| // For a 4-element dimension, max valid padding is 3 | ||
| let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0, 4.0]]); | ||
| // Padding of 3 on left is valid for width=4 (3 < 4) | ||
| let padded_tensor = tensor.pad((3, 3, 0, 0), PadMode::Reflect); | ||
| // Input: [1, 2, 3, 4] | ||
| // Reflect left 3: take indices [1, 2, 3] = [2, 3, 4], flip = [4, 3, 2] | ||
| // Reflect right 3: take indices [0, 1, 2] = [1, 2, 3], flip = [3, 2, 1] | ||
| // Result: [4, 3, 2, 1, 2, 3, 4, 3, 2, 1] | ||
| let expected = TensorData::from( | ||
| as_type!(FloatType: [[4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0]]), | ||
| ); | ||
| padded_tensor.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn padding_reflect_asymmetric_test() { | ||
| // Test asymmetric reflect padding | ||
| let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]); | ||
| // Asymmetric padding: left=2, right=1, top=1, bottom=2 | ||
| let padded_tensor = tensor.pad((2, 1, 1, 2), PadMode::Reflect); | ||
| let expected = TensorData::from(as_type!(FloatType: [ | ||
| [6.0, 5.0, 4.0, 5.0, 6.0, 5.0], | ||
| [3.0, 2.0, 1.0, 2.0, 3.0, 2.0], | ||
| [6.0, 5.0, 4.0, 5.0, 6.0, 5.0], | ||
| [9.0, 8.0, 7.0, 8.0, 9.0, 8.0], | ||
| [6.0, 5.0, 4.0, 5.0, 6.0, 5.0], | ||
| [3.0, 2.0, 1.0, 2.0, 3.0, 2.0], | ||
| ])); | ||
| padded_tensor.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| #[should_panic(expected = "Reflect padding")] | ||
| fn padding_reflect_exceeds_dimension_test() { | ||
| // Test that reflect padding panics when padding >= dim_size | ||
| let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0]]); | ||
| // Padding of 3 on width=3 should panic (3 >= 3, need padding < dim_size) | ||
| let _ = tensor.pad((3, 0, 0, 0), PadMode::Reflect); | ||
| } | ||
| #[test] | ||
| fn padding_edge_asymmetric_test() { | ||
| // Test asymmetric edge padding | ||
| let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); | ||
| // Asymmetric padding: left=2, right=1, top=3, bottom=1 | ||
| let padded_tensor = tensor.pad((2, 1, 3, 1), PadMode::Edge); | ||
| let expected = TensorData::from(as_type!(FloatType: [ | ||
| [1.0, 1.0, 1.0, 2.0, 3.0, 3.0], | ||
| [1.0, 1.0, 1.0, 2.0, 3.0, 3.0], | ||
| [1.0, 1.0, 1.0, 2.0, 3.0, 3.0], | ||
| [1.0, 1.0, 1.0, 2.0, 3.0, 3.0], | ||
| [4.0, 4.0, 4.0, 5.0, 6.0, 6.0], | ||
| [4.0, 4.0, 4.0, 5.0, 6.0, 6.0], | ||
| ])); | ||
| padded_tensor.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn padding_zero_padding_test() { | ||
| // Test that zero padding returns the original tensor unchanged | ||
| let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]); | ||
| let padded_constant = tensor.clone().pad((0, 0, 0, 0), PadMode::Constant(5.0)); | ||
| let padded_reflect = tensor.clone().pad((0, 0, 0, 0), PadMode::Reflect); | ||
| let padded_edge = tensor.clone().pad((0, 0, 0, 0), PadMode::Edge); | ||
| let expected = TensorData::from(as_type!(FloatType: [[1.0, 2.0], [3.0, 4.0]])); | ||
| padded_constant.into_data().assert_eq(&expected, false); | ||
| padded_reflect.into_data().assert_eq(&expected, false); | ||
| padded_edge.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn padding_empty_tensor_constant_test() { | ||
| // Test constant padding on an empty tensor (zero-sized dimension) | ||
| // This should work - creates a tensor filled with the constant value | ||
| let tensor: TestTensor<2> = TestTensor::empty([0, 3], &Default::default()); | ||
| // Padding an empty height dimension with constant should create a tensor of just padding | ||
| let padded = tensor.pad((0, 0, 2, 2), PadMode::Constant(1.0)); | ||
| // Result should be 4x3 (0 + 2 + 2 = 4 rows) | ||
| assert_eq!(padded.dims(), [4, 3]); | ||
| let expected = TensorData::from(as_type!(FloatType: [ | ||
| [1.0, 1.0, 1.0], | ||
| [1.0, 1.0, 1.0], | ||
| [1.0, 1.0, 1.0], | ||
| [1.0, 1.0, 1.0], | ||
| ])); | ||
| padded.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| #[should_panic(expected = "edge padding")] | ||
| fn padding_empty_tensor_edge_panics_test() { | ||
| // Test that edge padding panics on empty tensor | ||
| let tensor: TestTensor<2> = TestTensor::empty([0, 3], &Default::default()); | ||
| // Edge padding on zero-sized dimension should panic | ||
| let _ = tensor.pad((0, 0, 1, 1), PadMode::Edge); | ||
| } | ||
| #[test] | ||
| #[should_panic(expected = "Reflect padding")] | ||
| fn padding_empty_tensor_reflect_panics_test() { | ||
| // Test that reflect padding panics on empty tensor | ||
| let tensor: TestTensor<2> = TestTensor::empty([0, 3], &Default::default()); | ||
| // Reflect padding on zero-sized dimension should panic | ||
| let _ = tensor.pad((0, 0, 1, 1), PadMode::Reflect); | ||
| } | ||
| } |
+63
-55
| #[burn_tensor_testgen::testgen(select)] | ||
| mod tests { | ||
| use super::*; | ||
| use burn_tensor::{Tensor, TensorData, backend::Backend}; | ||
| use burn_tensor::{IndexingUpdateOp, Tensor, TensorData, backend::Backend}; | ||
@@ -85,3 +85,3 @@ #[test] | ||
| #[test] | ||
| fn should_select_assign_1d() { | ||
| fn should_select_add_1d() { | ||
| let device = Default::default(); | ||
@@ -92,3 +92,3 @@ let tensor = TestTensor::<1>::from_data([0.0, 1.0, 2.0], &device); | ||
| let output = tensor.select_assign(0, indices, values); | ||
| let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); | ||
| let expected = TensorData::from([3.0, 12.0, 3.0]); | ||
@@ -100,3 +100,3 @@ | ||
| #[test] | ||
| fn should_select_assign_1d_int() { | ||
| fn should_select_add_1d_int() { | ||
| let device = Default::default(); | ||
@@ -107,3 +107,3 @@ let tensor = TestTensorInt::<1>::from_data([7, 8, 9], &device); | ||
| let output = tensor.select_assign(0, indices, values); | ||
| let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); | ||
| let expected = TensorData::from([10, 19, 10]); | ||
@@ -115,3 +115,3 @@ | ||
| #[test] | ||
| fn should_select_assign_2d_dim0() { | ||
| fn should_select_add_2d_dim0() { | ||
| let device = Default::default(); | ||
@@ -122,3 +122,3 @@ let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); | ||
| let output = tensor.select_assign(0, indices, values); | ||
| let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); | ||
| let expected = TensorData::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]]); | ||
@@ -130,3 +130,3 @@ | ||
| #[test] | ||
| fn should_select_assign_2d_dim1() { | ||
| fn should_select_add_2d_dim1() { | ||
| let device = Default::default(); | ||
@@ -137,3 +137,3 @@ let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); | ||
| let output = tensor.select_assign(1, indices, values); | ||
| let output = tensor.select_assign(1, indices, values, IndexingUpdateOp::Add); | ||
| let expected = TensorData::from([[2.0, 2.0, 5.0], [8.0, 8.0, 11.0]]); | ||
@@ -177,3 +177,3 @@ | ||
| #[should_panic] | ||
| fn should_panic_select_assign_invalid_num_indices() { | ||
| fn should_panic_select_add_invalid_num_indices() { | ||
| let device = Default::default(); | ||
@@ -184,3 +184,3 @@ let tensor = TestTensorInt::<1>::from_data([0; 12], &device); | ||
| tensor.select_assign(0, indices, values); | ||
| tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); | ||
| } | ||
@@ -216,4 +216,4 @@ | ||
| #[test] | ||
| fn should_select_assign_bool_tensor() { | ||
| // Test that select_assign works for boolean tensors | ||
| fn should_select_add_bool_tensor() { | ||
| // Test that select_add works for boolean tensors | ||
| let device = Default::default(); | ||
@@ -224,4 +224,4 @@ let tensor = TestTensorBool::<1>::from_data([true, false, true], &device); | ||
| let output = tensor.select_assign(0, indices, values); | ||
| // Note: select_assign uses sum reduction, so: | ||
| let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); | ||
| // Note: select_add uses sum reduction, so: | ||
| // index 0: true OR false = true | ||
@@ -236,3 +236,3 @@ // index 2: true OR false = true | ||
| #[test] | ||
| fn should_select_assign_bool_overlapping_indices() { | ||
| fn should_select_add_bool_overlapping_indices() { | ||
| // Test accumulation behavior with overlapping indices | ||
@@ -244,3 +244,3 @@ let device = Default::default(); | ||
| let output = tensor.select_assign(0, indices, values); | ||
| let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); | ||
| // Index 0: false OR true OR false = true | ||
@@ -253,3 +253,3 @@ let expected = TensorData::from([true, true]); | ||
| #[test] | ||
| fn should_select_assign_bool_false_to_true_case() { | ||
| fn should_select_add_bool_false_to_true_case() { | ||
| // Test false OR true = true | ||
@@ -261,3 +261,3 @@ let device = Default::default(); | ||
| let output = tensor.select_assign(0, indices, values); | ||
| let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); | ||
| let expected = TensorData::from([true]); | ||
@@ -269,3 +269,3 @@ | ||
| #[test] | ||
| fn should_select_assign_bool_true_or_true_accumulation() { | ||
| fn should_select_add_bool_true_or_true_accumulation() { | ||
| // Test multiple true accumulations | ||
@@ -277,3 +277,3 @@ let device = Default::default(); | ||
| let output = tensor.select_assign(0, indices, values); | ||
| let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); | ||
| let expected = TensorData::from([true, false]); | ||
@@ -292,5 +292,6 @@ | ||
| let optimized_result = tensor | ||
| .clone() | ||
| .select_assign(0, indices.clone(), values.clone()); | ||
| let optimized_result = | ||
| tensor | ||
| .clone() | ||
| .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add); | ||
@@ -300,3 +301,3 @@ // Manual default implementation logic | ||
| let int_values = values.int(); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); | ||
| let default_result = assigned.greater_elem(0); | ||
@@ -310,3 +311,3 @@ | ||
| #[test] | ||
| fn should_select_assign_bool_overlapping_indices_vs_default() { | ||
| fn should_select_add_bool_overlapping_indices_vs_default() { | ||
| // Test overlapping indices against default implementation | ||
@@ -318,9 +319,10 @@ let device = Default::default(); | ||
| let optimized_result = tensor | ||
| .clone() | ||
| .select_assign(0, indices.clone(), values.clone()); | ||
| let optimized_result = | ||
| tensor | ||
| .clone() | ||
| .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add); | ||
| let int_tensor = tensor.int(); | ||
| let int_values = values.int(); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); | ||
| let default_result = assigned.greater_elem(0); | ||
@@ -334,3 +336,3 @@ | ||
| #[test] | ||
| fn should_select_assign_bool_true_or_true_accumulation_vs_default() { | ||
| fn should_select_add_bool_true_or_true_accumulation_vs_default() { | ||
| // Test multiple true accumulations against default implementation | ||
@@ -342,9 +344,10 @@ let device = Default::default(); | ||
| let optimized_result = tensor | ||
| .clone() | ||
| .select_assign(0, indices.clone(), values.clone()); | ||
| let optimized_result = | ||
| tensor | ||
| .clone() | ||
| .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add); | ||
| let int_tensor = tensor.int(); | ||
| let int_values = values.int(); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); | ||
| let default_result = assigned.greater_elem(0); | ||
@@ -358,3 +361,3 @@ | ||
| #[test] | ||
| fn should_select_assign_bool_false_to_true_case_vs_default() { | ||
| fn should_select_add_bool_false_to_true_case_vs_default() { | ||
| // Test false OR true case against default implementation | ||
@@ -368,9 +371,10 @@ use burn_tensor::backend::Backend; | ||
| let optimized_result = tensor | ||
| .clone() | ||
| .select_assign(0, indices.clone(), values.clone()); | ||
| let optimized_result = | ||
| tensor | ||
| .clone() | ||
| .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add); | ||
| let int_tensor = tensor.int(); | ||
| let int_values = values.int(); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); | ||
| let default_result = assigned.greater_elem(0); | ||
@@ -384,3 +388,3 @@ | ||
| #[test] | ||
| fn should_select_assign_bool_tensor_vs_default() { | ||
| fn should_select_add_bool_tensor_vs_default() { | ||
| // Test existing basic case against default implementation | ||
@@ -394,9 +398,10 @@ use burn_tensor::backend::Backend; | ||
| let optimized_result = tensor | ||
| .clone() | ||
| .select_assign(0, indices.clone(), values.clone()); | ||
| let optimized_result = | ||
| tensor | ||
| .clone() | ||
| .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add); | ||
| let int_tensor = tensor.int(); | ||
| let int_values = values.int(); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); | ||
| let default_result = assigned.greater_elem(0); | ||
@@ -418,3 +423,3 @@ | ||
| let output = tensor.select_assign(0, indices, values); | ||
| let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add); | ||
| let replacement_expected = TensorData::from([false]); | ||
@@ -437,3 +442,3 @@ | ||
| let int_values = values.int(); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values); | ||
| let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add); | ||
| let default_result = assigned.greater_elem(0); | ||
@@ -465,4 +470,4 @@ let replacement_expected = TensorData::from([false]); | ||
| #[test] | ||
| fn should_select_assign_with_negative_dim_2d() { | ||
| // Test select_assign with negative dimension on 2D tensor | ||
| fn should_select_add_with_negative_dim_2d() { | ||
| // Test select_add with negative dimension on 2D tensor | ||
| let device = Default::default(); | ||
@@ -474,6 +479,9 @@ let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device); | ||
| // Using -1 should refer to the last dimension (dim 1) | ||
| let output_neg = tensor | ||
| .clone() | ||
| .select_assign(-1, indices.clone(), values.clone()); | ||
| let output_pos = tensor.select_assign(1, indices, values); | ||
| let output_neg = tensor.clone().select_assign( | ||
| -1, | ||
| indices.clone(), | ||
| values.clone(), | ||
| IndexingUpdateOp::Add, | ||
| ); | ||
| let output_pos = tensor.select_assign(1, indices, values, IndexingUpdateOp::Add); | ||
@@ -498,3 +506,3 @@ output_neg | ||
| #[should_panic] | ||
| fn should_panic_select_assign_negative_dim_out_of_bounds() { | ||
| fn should_panic_select_add_negative_dim_out_of_bounds() { | ||
| let device = Default::default(); | ||
@@ -506,4 +514,4 @@ let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device); | ||
| // This should panic because -3 is out of bounds for a 2D tensor | ||
| tensor.select_assign(-3, indices, values); | ||
| tensor.select_assign(-3, indices, values, IndexingUpdateOp::Add); | ||
| } | ||
| } |
@@ -354,2 +354,41 @@ #[burn_tensor_testgen::testgen(slice_assign)] | ||
| } | ||
| #[test] | ||
| fn should_support_slice_assign_empty_range() { | ||
| let device = Default::default(); | ||
| let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device); | ||
| let values: TestTensor<2> = TestTensor::empty([2, 0], &device); | ||
| // Empty slice assignment (start == end) should be a no-op | ||
| let output = tensor.clone().slice_assign([0..2, 1..1], values); | ||
| let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); | ||
| output.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn should_support_slice_assign_empty_range_1d() { | ||
| let device = Default::default(); | ||
| let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0, 4.0, 5.0], &device); | ||
| let values: TestTensor<1> = TestTensor::empty([0], &device); | ||
| // Empty slice assignment should return tensor unchanged | ||
| let output = tensor.clone().slice_assign([2..2], values); | ||
| let expected = TensorData::from([1.0, 2.0, 3.0, 4.0, 5.0]); | ||
| output.into_data().assert_eq(&expected, false); | ||
| } | ||
| #[test] | ||
| fn should_support_slice_assign_empty_range_int() { | ||
| let device = Default::default(); | ||
| let tensor = TestTensorInt::<1>::from_data([1, 2, 3, 4, 5], &device); | ||
| let values: TestTensorInt<1> = TestTensorInt::empty([0], &device); | ||
| // Empty slice assignment for int tensor | ||
| let output = tensor.clone().slice_assign([3..3], values); | ||
| let expected = TensorData::from([1i32, 2, 3, 4, 5]); | ||
| output.into_data().assert_eq(&expected, false); | ||
| } | ||
| } |
@@ -94,3 +94,3 @@ #[burn_tensor_testgen::testgen(q_gather_scatter)] | ||
| let output = tensor.scatter(0, indices, values); | ||
| let output = tensor.scatter_add(0, indices, values); | ||
@@ -110,3 +110,3 @@ // Precision 1 to approximate de/quantization errors | ||
| let output = tensor.scatter(0, indices, values); | ||
| let output = tensor.scatter_add(0, indices, values); | ||
@@ -126,3 +126,3 @@ // Precision 1 to approximate de/quantization errors | ||
| let output = tensor.scatter(1, indices, values); | ||
| let output = tensor.scatter_add(1, indices, values); | ||
@@ -151,3 +151,3 @@ // Precision 1 to approximate de/quantization errors | ||
| let output = tensor.scatter(1, indices, values); | ||
| let output = tensor.scatter_add(1, indices, values); | ||
| let expected = TensorData::from([ | ||
@@ -171,3 +171,3 @@ [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]], | ||
| let output = tensor.scatter(1, indices, values); | ||
| let output = tensor.scatter_add(1, indices, values); | ||
@@ -188,4 +188,4 @@ // Precision 1 to approximate de/quantization errors | ||
| tensor.scatter(0, indices, values); | ||
| tensor.scatter_add(0, indices, values); | ||
| } | ||
| } |
| use crate::{Device, ElementConversion, Shape, TensorData, backend::Backend, ops::IntTensor}; | ||
| use alloc::vec::Vec; | ||
| /// Compute the indices of the elements that are non-zero, grouped by element. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `data` - The input tensor data. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A 2D tensor containing the indices of all non-zero elements of the given tensor. | ||
| /// Each row contains the indices of a non-zero element. | ||
| /// | ||
| /// # Remarks | ||
| /// | ||
| /// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. | ||
| /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved | ||
| /// by static dispatch. It is not designed for direct usage by users, and not recommended to import | ||
| /// or use this function directly. | ||
| pub fn argwhere_data<B: Backend>(data: TensorData, device: &Device<B>) -> IntTensor<B> { | ||
| let dims = &data.shape; | ||
| let ndims = dims.len(); | ||
| let count_nonzero = data.iter::<bool>().filter(|&v| v).count(); | ||
| /// Converts a flat index into a vector of indices for the specified tensor shape | ||
| fn unravel_index<B: Backend>(index: usize, shape: &[usize]) -> Vec<B::IntElem> { | ||
| shape | ||
| .iter() | ||
| .rev() | ||
| .scan(index, |i, size| { | ||
| let dim_idx = *i % size; | ||
| *i /= size; | ||
| Some((dim_idx as i64).elem()) | ||
| }) | ||
| .collect::<Vec<_>>() | ||
| .into_iter() | ||
| .rev() | ||
| .collect() | ||
| } | ||
| let indices = data | ||
| .iter::<bool>() | ||
| .enumerate() | ||
| .filter_map(|(index, v)| if v { Some(index) } else { None }) | ||
| .map(|index| unravel_index::<B>(index, dims)) | ||
| .collect::<Vec<_>>() | ||
| .concat(); | ||
| B::int_from_data( | ||
| TensorData::new(indices, Shape::new([count_nonzero, ndims])), | ||
| device, | ||
| ) | ||
| } |
| use crate::{DType, Shape, backend::Backend}; | ||
/// A type-level representation of the kind of a float tensor.
#[derive(Clone, Debug)]
pub struct Float;

/// A type-level representation of the kind of an int tensor.
#[derive(Clone, Debug)]
pub struct Int;

/// A type-level representation of the kind of a bool tensor.
#[derive(Clone, Debug)]
pub struct Bool;
#[derive(Debug, Clone)]
/// A primitive tensor representation.
///
/// A float tensor is stored either in full precision
/// ([`Float`](TensorPrimitive::Float)) or in quantized form
/// ([`QFloat`](TensorPrimitive::QFloat)).
pub enum TensorPrimitive<B: Backend> {
    /// Float tensor primitive.
    Float(B::FloatTensorPrimitive),
    /// Quantized float tensor primitive.
    QFloat(B::QuantizedTensorPrimitive),
}
| impl<B: Backend> TensorPrimitive<B> { | ||
| /// Returns the full tensor representation. | ||
| pub fn tensor(self) -> B::FloatTensorPrimitive { | ||
| match self { | ||
| Self::QFloat(tensor) => B::dequantize(tensor), | ||
| Self::Float(tensor) => tensor, | ||
| } | ||
| } | ||
| } | ||
| impl<B: Backend> TensorMetadata for TensorPrimitive<B> { | ||
| fn dtype(&self) -> DType { | ||
| match self { | ||
| TensorPrimitive::Float(tensor) => tensor.dtype(), | ||
| TensorPrimitive::QFloat(tensor) => tensor.dtype(), | ||
| } | ||
| } | ||
| fn shape(&self) -> Shape { | ||
| match self { | ||
| TensorPrimitive::Float(tensor) => tensor.shape(), | ||
| TensorPrimitive::QFloat(tensor) => tensor.shape(), | ||
| } | ||
| } | ||
| fn rank(&self) -> usize { | ||
| match self { | ||
| TensorPrimitive::Float(tensor) => tensor.rank(), | ||
| TensorPrimitive::QFloat(tensor) => tensor.rank(), | ||
| } | ||
| } | ||
| } | ||
/// Tensor metadata trait for tensor primitive.
///
/// Exposes a primitive's element type, shape and number of dimensions.
pub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {
    /// The dtype of the tensor.
    fn dtype(&self) -> DType;
    /// The shape of the tensor.
    fn shape(&self) -> Shape;
    /// The number of dimensions of the tensor.
    ///
    /// The default implementation derives it from [`shape`](TensorMetadata::shape).
    fn rank(&self) -> usize {
        self.shape().num_dims()
    }
}
/// A type-level representation of the kind of a tensor.
/// Metadata access is lazy.
pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
    /// The primitive type of the tensor.
    type Primitive: TensorMetadata;
    /// The name of the tensor kind (`"Float"`, `"Int"` or `"Bool"`).
    fn name() -> &'static str;
}
// Float is the only kind whose primitive is wrapped: `TensorPrimitive` can hold
// either a plain float tensor or a quantized one.
impl<B: Backend> TensorKind<B> for Float {
    type Primitive = TensorPrimitive<B>;
    fn name() -> &'static str {
        "Float"
    }
}

impl<B: Backend> TensorKind<B> for Int {
    type Primitive = B::IntTensorPrimitive;
    fn name() -> &'static str {
        "Int"
    }
}

impl<B: Backend> TensorKind<B> for Bool {
    type Primitive = B::BoolTensorPrimitive;
    fn name() -> &'static str {
        "Bool"
    }
}
| use core::cmp::Ordering; | ||
| use crate::{ | ||
| BasicOps, Device, Element, ElementComparison, ElementConversion, TensorData, TensorKind, | ||
| backend::Backend, | ||
| ops::{IntElem, IntTensor}, | ||
| }; | ||
| use alloc::{vec, vec::Vec}; | ||
| use burn_std::reader::try_read_sync; | ||
| /// Sort the elements of the input `tensor` by value along a given dimension. | ||
| /// | ||
| /// This sort is unstable (i.e., may reorder equal elements). | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The input tensor. | ||
| /// * `dim` - The axis along which to sort. | ||
| /// * `descending` - The sorting order. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as the input tensor, where the elements are sorted by value. | ||
| /// | ||
| /// # Remarks | ||
| /// | ||
| /// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. | ||
| /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved | ||
| /// by static dispatch. It is not designed for direct usage by users, and not recommended to import | ||
| /// or use this function directly. | ||
| pub fn sort<B: Backend, K: TensorKind<B> + BasicOps<B>>( | ||
| tensor: K::Primitive, | ||
| dim: usize, | ||
| descending: bool, | ||
| ) -> K::Primitive | ||
| where | ||
| <K as BasicOps<B>>::Elem: Element, | ||
| { | ||
| let device = K::device(&tensor); | ||
| let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."; | ||
| let data = try_read_sync(K::into_data_async(tensor)) | ||
| .expect(msg) | ||
| .expect(msg); | ||
| sort_data::<B, K>(data, dim, &device, descending) | ||
| } | ||
/// Sort the elements of `data` by value along a given dimension, returning the
/// result as a tensor primitive on the given device.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `data` - The input tensor data, sorted in place before conversion.
/// * `dim` - The axis along which to sort.
/// * `device` - The device used to create the output tensor.
/// * `descending` - The sorting order.
pub fn sort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    mut data: TensorData,
    dim: usize,
    device: &Device<B>,
    descending: bool,
) -> K::Primitive
where
    <K as BasicOps<B>>::Elem: Element,
{
    let dims = data.shape.clone();
    let data_slice = data.as_mut_slice().unwrap();
    if dims.len() == 1 {
        // 1D sort
        data_slice.sort_unstable_by(|&a, &b| compare(&a, &b, descending));
    } else {
        // N-D case: sort each 1D lane along `dim` independently.
        sort_slice::<B, K>(data_slice, &dims, dim, None, false, descending);
    }
    K::from_data(data, device)
}
| /// Sort the elements of the input `tensor` by value along a given dimension. | ||
| /// | ||
| /// This sort is unstable (i.e., may reorder equal elements). | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The input tensor. | ||
| /// * `dim` - The axis along which to sort. | ||
| /// * `descending` - The sorting order. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as the input tensor and corresponding indices, where | ||
| /// the elements are sorted by value and the indices map back to the original input tensor. | ||
| /// | ||
| /// # Remarks | ||
| /// | ||
| /// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. | ||
| /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved | ||
| /// by static dispatch. It is not designed for direct usage by users, and not recommended to import | ||
| /// or use this function directly. | ||
| pub fn sort_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>( | ||
| tensor: K::Primitive, | ||
| dim: usize, | ||
| descending: bool, | ||
| ) -> (K::Primitive, IntTensor<B>) | ||
| where | ||
| <K as BasicOps<B>>::Elem: Element, | ||
| { | ||
| let device = K::device(&tensor); | ||
| let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."; | ||
| let data = try_read_sync(K::into_data_async(tensor)) | ||
| .expect(msg) | ||
| .expect(msg); | ||
| sort_data_with_indices::<B, K>(data, dim, &device, descending) | ||
| } | ||
| fn sort_data_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>( | ||
| mut data: TensorData, | ||
| dim: usize, | ||
| device: &Device<B>, | ||
| descending: bool, | ||
| ) -> (K::Primitive, IntTensor<B>) | ||
| where | ||
| <K as BasicOps<B>>::Elem: Element, | ||
| { | ||
| let dims = data.shape.clone(); | ||
| let mut indices_data = dim_indices::<B>(&dims, dim); | ||
| let data_slice = data.as_mut_slice().unwrap(); | ||
| if dims.len() == 1 { | ||
| // 1D sort | ||
| indices_data.sort_unstable_by(|&a, &b| { | ||
| compare( | ||
| &data_slice[a.elem::<i64>() as usize], | ||
| &data_slice[b.elem::<i64>() as usize], | ||
| descending, | ||
| ) | ||
| }); | ||
| // Permute data in-place by the sorted indices | ||
| let mut indices = indices_data | ||
| .clone() | ||
| .iter() | ||
| .map(|i| i.elem::<i64>() as usize) | ||
| .collect::<Vec<_>>(); | ||
| for idx in 0..indices.len() { | ||
| if indices[idx] != idx { | ||
| let mut current_idx = idx; | ||
| loop { | ||
| let target_idx = indices[current_idx]; | ||
| indices[current_idx] = current_idx; | ||
| if indices[target_idx] == target_idx { | ||
| // correct position | ||
| break; | ||
| } | ||
| // Permute data by indices | ||
| data_slice.swap(current_idx, target_idx); | ||
| current_idx = target_idx; | ||
| } | ||
| } | ||
| } | ||
| } else { | ||
| sort_slice::<B, K>( | ||
| data_slice, | ||
| &dims, | ||
| dim, | ||
| Some(&mut indices_data), | ||
| true, | ||
| descending, | ||
| ); | ||
| } | ||
| let shape = data.shape.clone(); | ||
| ( | ||
| K::from_data(data, device), | ||
| B::int_from_data(TensorData::new(indices_data, shape), device), | ||
| ) | ||
| } | ||
| /// Returns the indices that sort the elements of the input `tensor` along a given dimension. | ||
| /// | ||
| /// This sort is unstable (i.e., may reorder equal elements). | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The input tensor. | ||
| /// * `dim` - The axis along which to sort. | ||
| /// * `descending` - The sorting order. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as the input tensor the indices map back to the original input tensor. | ||
| /// | ||
| /// # Remarks | ||
| /// | ||
| /// This is a fallback solution that used only when the backend doesn't have the corresponding implementation. | ||
| /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved | ||
| /// by static dispatch. It is not designed for direct usage by users, and not recommended to import | ||
| /// or use this function directly. | ||
| pub fn argsort<B: Backend, K: TensorKind<B> + BasicOps<B>>( | ||
| tensor: K::Primitive, | ||
| dim: usize, | ||
| descending: bool, | ||
| ) -> IntTensor<B> | ||
| where | ||
| <K as BasicOps<B>>::Elem: Element, | ||
| { | ||
| let device = K::device(&tensor); | ||
| let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."; | ||
| let data = try_read_sync(K::into_data_async(tensor)) | ||
| .expect(msg) | ||
| .expect(msg); | ||
| argsort_data::<B, K>(data, dim, &device, descending) | ||
| } | ||
| fn argsort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>( | ||
| mut data: TensorData, | ||
| dim: usize, | ||
| device: &Device<B>, | ||
| descending: bool, | ||
| ) -> IntTensor<B> | ||
| where | ||
| <K as BasicOps<B>>::Elem: Element, | ||
| { | ||
| let dims = data.shape.clone(); | ||
| let mut indices_data = dim_indices::<B>(&dims, dim); | ||
| if dims.len() == 1 { | ||
| // 1D sort | ||
| let slice = data.as_slice::<<K as BasicOps<B>>::Elem>().unwrap(); | ||
| indices_data.sort_unstable_by(|&a, &b| { | ||
| compare( | ||
| &slice[a.elem::<i64>() as usize], | ||
| &slice[b.elem::<i64>() as usize], | ||
| descending, | ||
| ) | ||
| }); | ||
| } else { | ||
| sort_slice::<B, K>( | ||
| data.as_mut_slice().unwrap(), | ||
| &dims, | ||
| dim, | ||
| Some(&mut indices_data), | ||
| false, | ||
| descending, | ||
| ); | ||
| } | ||
| B::int_from_data(TensorData::new(indices_data, data.shape), device) | ||
| } | ||
/// Sort the elements by value along a given dimension.
///
/// When `indices` are not provided, the `data` is sorted.
/// Otherwise, the `indices` are sorted based on the value of the elements in `data`,
/// and if `permute_both` is enabled then the data is also sorted.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `data` - Flat (row-major) tensor data, mutated in place.
/// * `dims` - The shape of the tensor that `data` represents.
/// * `dim` - The axis along which to sort.
/// * `indices` - Optional flat buffer of per-element indices, permuted alongside the sort.
/// * `permute_both` - When `indices` is provided, also permute `data` itself.
/// * `descending` - The sorting order.
fn sort_slice<B: Backend, K: BasicOps<B>>(
    data: &mut [<K as BasicOps<B>>::Elem],
    dims: &[usize],
    dim: usize,
    mut indices: Option<&mut [IntElem<B>]>,
    permute_both: bool,
    descending: bool,
) where
    <K as BasicOps<B>>::Elem: Element,
{
    let ndims = dims.len();
    let strides = compute_strides(dims);
    // Dimensions to access elements to sort
    // (the sorted axis is collapsed to 1 so each "output" position names one lane).
    let mut sort_dims = dims.to_vec();
    sort_dims[dim] = 1;
    let strides_out = compute_strides(&sort_dims);
    // Number of groups to sort: one 1D lane per combination of the other dimensions.
    let num_sorts: usize = dims
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != dim)
        .map(|(_, d)| d)
        .product();
    // TODO: run each sort in parallel
    // run_par!(|| {
    // iter_range_par!(0, num_sorts).for_each(|id| {...})
    for id in 0..num_sorts {
        // Decode lane `id` into a flat offset into `data`, and record the
        // stride/length of the sorted axis for this lane.
        let mut index_offset = 0;
        let mut stride_dim = 0;
        let mut shape_dim = 0;
        for d in 0..ndims {
            let stride_input = strides[d];
            let stride_output = strides_out[d];
            let shape_output = sort_dims[d];
            let num_block = id / stride_output % shape_output;
            if d != dim {
                index_offset += num_block * stride_input;
            } else {
                let shape_input = dims[d];
                stride_dim = stride_input;
                shape_dim = shape_input;
                index_offset += num_block;
            }
        }
        // For each group, sort the indices based on the element values
        // NOTE: Sorting methods like `sort_unstable_by` are in-place but we need to sort
        // different views/groups of the underlying data, so the swap is performed on the elements
        // of the (flat index, element value) collection.
        let mut elements = (0..shape_dim)
            .map(|d| {
                let flat_index = d * stride_dim + index_offset;
                let elem = data[flat_index];
                // (position along the lane, flat index into `data`, value)
                (d, flat_index, elem)
            })
            .collect::<Vec<_>>();
        elements.sort_unstable_by(|&(_, _, a), &(_, _, b)| compare(&a, &b, descending));
        // Permute data in-place by the sorted indices
        // (in-place cycle permutation: each displacement cycle is walked exactly once).
        for idx in 0..elements.len() {
            if elements[idx].0 != idx {
                let mut current_idx = idx;
                loop {
                    let target_idx = elements[current_idx].0;
                    elements[current_idx].0 = current_idx;
                    if elements[target_idx].0 == target_idx {
                        // correct position
                        break;
                    }
                    if indices.is_none() || permute_both {
                        // Permute data by indices
                        data.swap(elements[current_idx].1, elements[target_idx].1);
                    }
                    if let Some(ref mut indices_data) = indices {
                        // Permute data element indices
                        indices_data.swap(elements[current_idx].1, elements[target_idx].1);
                    }
                    current_idx = target_idx;
                }
            }
        }
    }
}
/// Computes the row-major steps (strides) for each dimension when traversing
/// a flat array of the given shape.
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    // Walk the dimensions innermost-first, accumulating the running product,
    // then restore the outermost-first order.
    let mut running = 1;
    let mut strides: Vec<usize> = dims
        .iter()
        .rev()
        .map(|&size| {
            let stride = running;
            running *= size;
            stride
        })
        .collect();
    strides.reverse();
    strides
}
| /// Generates the indices for each element along the specified dimension. | ||
| fn dim_indices<B: Backend>(dims: &[usize], dim: usize) -> Vec<IntElem<B>> { | ||
| if dims.len() == 1 { | ||
| (0..dims[dim]) | ||
| .map(|i| (i as i64).elem::<IntElem<B>>()) | ||
| .collect::<Vec<_>>() | ||
| } else { | ||
| // Dimension indices tensor | ||
| let numel_leading_dims: usize = dims[..dim].iter().product(); | ||
| let numel_trailing_dims: usize = dims[dim + 1..].iter().product(); | ||
| (0..dims[dim]) | ||
| .map(|i| [(i as i64).elem::<IntElem<B>>()].repeat(numel_trailing_dims)) | ||
| .collect::<Vec<_>>() | ||
| .concat() | ||
| .repeat(numel_leading_dims) | ||
| } | ||
| } | ||
| /// Compare two elements | ||
| fn compare<E: ElementComparison>(a: &E, b: &E, descending: bool) -> Ordering { | ||
| if descending { b.cmp(a) } else { a.cmp(b) } | ||
| } |
| use alloc::string::String; | ||
| use serde::{Deserialize, Serialize}; | ||
| use crate::tensor::Element; | ||
| use crate::{TensorData, TensorMetadata}; | ||
| use crate::{ops::*, quantization::QTensorPrimitive}; | ||
| use super::DeviceOps; | ||
| /// This trait defines all types and functions needed for a backend to be used with burn. | ||
| /// | ||
| /// ## Design | ||
| /// | ||
| /// This trait aims to be as unopinionated as possible and allows implementations to define | ||
| /// their own types and patterns. Therefore, there are few pre-defined abstractions baked | ||
| /// into this trait. | ||
| /// | ||
| /// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`. | ||
| /// Since we minimize assumptions, we chose to separate these types, as they are used in | ||
| /// different contexts. However, some backends may have a generic tensor type that is used | ||
| /// for all data types. | ||
| /// | ||
| /// ### Eager Mode | ||
| /// | ||
| /// Because burn supports dynamic graphs, the backend trait is designed around kernel | ||
| /// implementations that can be called without any mutable context or graph. This may not be | ||
| /// ideal for backends that want to configure their computational graphs and execute them | ||
| /// multiple times. | ||
| /// | ||
| /// To implement this kind of backend, channels could be used to communicate with a backend | ||
| /// server thread to build the computation graphs and re-execute the ones that are repeated, | ||
| /// with some form of cache. Once that pattern has matured, a graph mode backend trait could | ||
| /// be extracted from it, allowing other backends of the same kind to be quickly integrated | ||
| /// with burn. This pattern could also be used to create an operation fusion trait, which | ||
| /// allows backends to define what kind of graph structures can be fused into one operation. | ||
| /// | ||
| /// ### Multi-Threaded | ||
| /// | ||
| /// Backend tensor types are all `Clone` + `Send`, which allows them to be safely | ||
| /// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc), | ||
| /// which avoids copying the tensor's buffer. Note that it is still possible to mutate and | ||
| /// reuse tensors' buffer without locking; see the next section on the Mutable API. | ||
| /// | ||
| /// ### Mutable API | ||
| /// | ||
| /// There is no mutable or inplace operation API to implement, but that does not mean that | ||
| /// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and | ||
| /// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable | ||
| /// reference to their tensor buffer data structure if the tensor is not shared. In that case, | ||
| /// backends can dispatch to their owned inplace operations for better performance. | ||
| /// | ||
| /// ## Documentation | ||
| /// | ||
| /// Most of the documentation for each function can be found on the user API [tensor struct](crate::Tensor). | ||
| /// For modules, public functions are often created, which can be used by `burn-core` modules. | ||
pub trait Backend:
    FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Device type.
    type Device: DeviceOps;
    /// Tensor primitive to be used for all float operations.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Default float element type.
    type FloatElem: Element;
    /// Tensor primitive to be used for all int operations.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Int element type.
    type IntElem: Element;
    /// Tensor primitive to be used for all bool operations.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Bool element type.
    type BoolElem: Element;
    /// Tensor primitive to be used for all quantized operations.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;
    /// If autodiff is enabled.
    fn ad_enabled() -> bool {
        // Overridden by autodiff-capable backends; plain backends report false.
        false
    }
    /// Sets the current allocation mode to persistent.
    ///
    /// The default implementation simply runs `func(input)` without changing
    /// any allocation behavior.
    #[allow(unused_variables)]
    fn memory_persistent_allocations<Output, Input, Func: Fn(Input) -> Output>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output {
        func(input)
    }
    /// Manually triggers a memory cleanup on the given device.
    ///
    /// The default implementation is a no-op.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}
    /// Name of the backend.
    fn name(device: &Self::Device) -> String;
    /// Seeds the backend on the specified device.
    ///
    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed
    /// that at least the specified device will be seeded.
    ///
    /// In all cases, this should ensure deterministic execution for a single-threaded program.
    fn seed(device: &Self::Device, seed: u64);
    /// Sync the backend, ensure that all computation are finished.
    ///
    /// The default implementation returns `Ok(())` without waiting on anything.
    fn sync(_device: &Self::Device) -> Result<(), SyncError> {
        Ok(())
    }
    /// Marks the given data as being used as a staging buffer for transfer between CPU and
    /// accelerators like GPUs.
    ///
    /// The given data might be transferred to pinned memory or another format to improve data transfer
    /// speed.
    ///
    /// The default implementation does nothing.
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }
}
/// An error that can happen when syncing a backend.
#[derive(Serialize, Deserialize, Debug)]
pub enum SyncError {
    /// A generic error happened while syncing.
    Generic {
        /// The details
        context: String,
    },
    /// Syncing the device isn't supported.
    NotSupported {
        /// The details
        context: String,
    },
}
/// An error that can happen while executing work on a device.
#[derive(Serialize, Deserialize, Debug)]
pub enum ExecutionError {
    /// A generic error happened during execution.
    Generic {
        /// The details
        context: String,
    },
}
| impl core::fmt::Display for SyncError { | ||
| fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { | ||
| match self { | ||
| SyncError::Generic { context } => { | ||
| f.write_fmt(format_args!("An error happened while syncing: {}", context)) | ||
| } | ||
| SyncError::NotSupported { context } => { | ||
| f.write_fmt(format_args!("Can't sync the device: {}", context)) | ||
| } | ||
| } | ||
| } | ||
| } | ||
/// Trait that allows a backend to support autodiff.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;
    /// Gradients type.
    type Gradients: Send;
    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;
    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;
    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the given gradients.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;
    /// Replace the gradients of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, register it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradients are replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );
    /// Returns the float tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;
    /// Returns the int tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;
    /// Returns the bool tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;
    /// Returns the quantized tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;
    /// Converts the inner backend float tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;
    /// Converts the inner backend int tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;
    /// Converts the inner backend bool tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;
    /// Converts the inner backend quantized tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
}
| pub use burn_std::device::*; | ||
/// The handle device trait allows getting an id for a backend device.
pub trait DeviceOps:
    Clone + Default + PartialEq + Send + Sync + core::fmt::Debug + burn_std::device::Device
{
    /// Returns the [device id](DeviceId).
    ///
    /// The default implementation delegates to `burn_std::device::Device::to_id`.
    fn id(&self) -> DeviceId {
        self.to_id()
    }
}
mod base;
mod device;

// Re-export the backend trait and device APIs at the module root.
pub use base::*;
pub use device::*;
| use crate::TensorMetadata; | ||
| use crate::{ElementConversion, backend::Backend}; | ||
| use core::f64::consts::SQRT_2; | ||
| use super::FloatTensor; | ||
/// Activation function operations.
///
/// This trait let backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
    /// Applies the LeakyReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `negative_slope` - The negative_slope value that values smaller than 0 are multiplied with.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn leaky_relu(tensor: FloatTensor<B>, negative_slope: super::FloatElem<B>) -> FloatTensor<B> {
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope.elem());
        // Update the tensor where the values are `< 0` by `tensor * negative_slope`.
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the ReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // Zero out every element `<= 0`; positive values pass through unchanged.
        let mask = B::float_lower_equal_elem(tensor.clone(), 0.elem());
        B::float_mask_fill(tensor, mask, 0.elem())
    }

    /// Applies the ReLU activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the forward ReLU.
    /// * `grad` - The incoming gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // ReLU derivative: 1 where output > 0, 0 elsewhere — so zero the
        // gradient wherever the forward output was `<= 0`.
        let mask = B::float_lower_equal_elem(output, 0.elem());
        B::float_mask_fill(grad, mask, 0.elem())
    }

    /// Applies the Gelu activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // Exact (erf-based) form: gelu(x) = x * (1 + erf(x / sqrt(2))) / 2
        let x = B::float_div_scalar(tensor.clone(), SQRT_2.elem());
        let x = B::float_erf(x);
        let x = B::float_add_scalar(x, 1i32.elem());
        let x = B::float_mul(tensor, x);
        B::float_div_scalar(x, 2i32.elem())
    }

    /// Applies the PReLu activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor
    /// * `alpha` - The weight tensor
    ///
    /// # Returns
    ///
    /// The output tensor: `tensor` where `>= 0`, `tensor * alpha` where `< 0`.
    fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul(tensor.clone(), alpha);
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the Gelu activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Derivative of the approximate gelu implementation based on tanh:
        //   gelu(x) ≈ 0.5 * x * (1 + tanh(k * (x + a * x³)))
        // with k = sqrt(2/π) ≈ 0.797885 and a = 0.044715. The constants are:
        //   constant_1 = k * a       ≈ 0.0356774 (cubic term inside tanh)
        //   constant_2 = k           ≈ 0.797885  (linear term inside tanh)
        //   constant_3 = 3 * k * a/2 ≈ 0.0535161 (cubic term of the sech² factor)
        //   constant_4 = k / 2       ≈ 0.398942  (linear term of the sech² factor)
        let constant_1 = 0.0356774;
        let constant_2 = 0.797885;
        let constant_3 = 0.0535161;
        let constant_4 = 0.398942;
        let x3 = B::float_powi_scalar(x.clone(), 3.elem());
        let c1 = B::float_mul_scalar(x3.clone(), constant_1.elem());
        let c2 = B::float_mul_scalar(x.clone(), constant_2.elem());
        let c3 = B::float_mul_scalar(x3, constant_3.elem());
        let c4 = B::float_mul_scalar(x, constant_4.elem());
        // inner1 = k*x + k*a*x³ (the tanh argument); inner2 = its scaled derivative terms.
        let inner1 = B::float_add(c1, c2);
        let inner2 = B::float_add(c3, c4);
        let tanh = B::float_tanh(inner1);
        // sech²(t) = 1 - tanh²(t)
        let sech = B::float_powi_scalar(tanh.clone(), 2.elem());
        let sech = B::float_neg(sech);
        let sech = B::float_add_scalar(sech, 1.elem());
        let y1 = B::float_mul_scalar(tanh, 0.5.elem());
        let y2 = B::float_mul(inner2, sech);
        let y2 = B::float_add_scalar(y2, 0.5.elem());
        let y = B::float_add(y1, y2);
        B::float_mul(y, grad)
    }

    /// Applies the Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // sigmoid(x) = 1 / (1 + exp(-x)), computed as exp(-log(1 + exp(-x))).
        // The intermediate math runs in f32 and is cast back to the input
        // dtype afterwards — presumably to preserve accuracy for
        // reduced-precision float dtypes; TODO confirm.
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, crate::FloatDType::F32);
        let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar(
            B::float_exp(B::float_neg(tensor_full)),
            1.0.elem(),
        ))));
        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the Sigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the sigmoid function.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) = output * (1 - output)
        let value = B::float_mul(
            output.clone(),
            B::float_add_scalar(B::float_neg(output), 1.0.elem()),
        );
        B::float_mul(value, grad)
    }

    /// Applies the hard Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `alpha` - The alpha value that the tensor is multiplied with.
    /// * `beta` - The beta value that is added to the tensor
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn hard_sigmoid(
        tensor: FloatTensor<B>,
        alpha: super::FloatElem<B>,
        beta: super::FloatElem<B>,
    ) -> FloatTensor<B> {
        // hard_sigmoid(x) = clamp(alpha * x + beta, 0, 1), computed in f32
        // and cast back to the input dtype.
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, crate::FloatDType::F32);
        let tensor_tmp = B::float_clamp(
            B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha.elem()), beta.elem()),
            0.0.elem(),
            1.0.elem(),
        );
        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the LogSigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // To avoid overflow, we use the log-sum-exp trick.
        //
        // ```ignore
        // log(sigmoid(x)) = log(1/(1 + exp(-x)))
        //                 = log(1) - log(1 + exp(-x))
        //                 = -log(1 + exp(-x))
        //                 = -log(exp(0) + exp(-x))
        // ```
        // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
        // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
        // following equivalence:
        // ```ignore
        // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
        // ```
        //
        // This extends the range of values for which we obtain accurate results.

        // max(-x, 0)
        let tensor_neg = B::float_neg(tensor);
        let mask = B::float_lower_elem(tensor_neg.clone(), 0.elem());
        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0.elem());
        let max_elem_neg = B::float_neg(max_elem.clone());
        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(max_elem_neg.clone()),
            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
        );
        // -max(-x, 0) - log(z)
        B::float_sub(max_elem_neg, B::float_log(z))
    }

    /// Applies the LogSigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The input tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output gradient.
    fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Derivative of -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) is
        // -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z
        // where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        //
        // This simplifies to:
        // -max_derive - (z-1)/z if x is >= 0
        // -max_derive + (z-1)/z if x is < 0
        let shape = x.shape();
        let dtype = x.dtype();
        let device = B::float_device(&x);

        // max(-x, 0)
        let x_neg = B::float_neg(x);
        let mask = B::float_lower_elem(x_neg.clone(), 0.elem()); // -x < 0 or x >= 0
        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0.elem());
        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(B::float_neg(max_elem.clone())),
            B::float_exp(B::float_sub(x_neg, max_elem)),
        );

        // Derivative of max(-x, 0) is 1 if x < 0 or 0 if x >= 0
        let ones = B::float_ones(shape, &device, dtype.into());
        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0.elem());
        let sign = B::float_mask_fill(ones.clone(), mask, (-1).elem());

        // grad * (max_derive - sign * (1 - (1 / z)))
        B::float_mul(
            grad,
            B::float_sub(
                max_derive,
                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
            ),
        )
    }
}
use crate::backend::Backend;

// We provide some type aliases to improve the readability of using associated types without
// having to use the disambiguation syntax, e.g. `Device<B>` instead of `<B as Backend>::Device`.

/// Device type used by the backend.
pub type Device<B> = <B as Backend>::Device;

/// Float element type used by the backend.
pub type FloatElem<B> = <B as Backend>::FloatElem;
/// Integer element type used by the backend.
pub type IntElem<B> = <B as Backend>::IntElem;

/// Float tensor primitive type used by the backend.
pub type FloatTensor<B> = <B as Backend>::FloatTensorPrimitive;
/// Integer tensor primitive type used by the backend.
pub type IntTensor<B> = <B as Backend>::IntTensorPrimitive;
/// Boolean tensor primitive type used by the backend.
pub type BoolTensor<B> = <B as Backend>::BoolTensorPrimitive;
/// Quantized tensor primitive type used by the backend.
pub type QuantizedTensor<B> = <B as Backend>::QuantizedTensorPrimitive;
| use super::{ | ||
| BoolTensor, Device, FloatTensor, IntTensor, cat::cat_with_slice_assign, | ||
| repeat_dim::repeat_with_slice_assign, | ||
| }; | ||
| use crate::{ | ||
| Bool, ElementConversion, TensorData, TensorMetadata, argwhere_data, | ||
| backend::{Backend, ExecutionError}, | ||
| tensor::Shape, | ||
| }; | ||
| use alloc::vec::Vec; | ||
| use core::future::Future; | ||
/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
/// for documentation on each function.
pub trait BoolTensorOps<B: Backend> {
    /// Creates a new bool tensor.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The boolean tensor with the given shape.
    fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;

    /// Creates a new bool tensor filled false.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The boolean tensor filled with false.
    fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B>;

    /// Creates a new bool tensor filled true.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The boolean tensor filled with true.
    fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B>;

    /// Converts the tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The data structure with the tensor's data.
    fn bool_into_data(
        tensor: BoolTensor<B>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;

    /// Creates a tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the data.
    fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;

    /// Converts bool tensor to int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The int tensor with the same data as the bool tensor.
    fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;

    /// Converts bool tensor to float tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The float tensor with the same data as the bool tensor.
    fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;

    /// Gets the device of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;

    /// Moves the tensor to the device.
    fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;

    /// Reshapes the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `shape` - The new shape.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;

    /// Gets the values from the tensor for the given ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// The tensor with the values for the given slices.
    ///
    /// # Note
    ///
    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
    /// be passed to this method. Backend implementations do not need to handle empty slices.
    fn bool_slice(tensor: BoolTensor<B>, slices: &[crate::Slice]) -> BoolTensor<B>;

    /// Sets the values in the tensor for the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `slices` - The slices to set the values for.
    /// * `value` - The values to set.
    ///
    /// # Returns
    ///
    /// The tensor with the values set for the given slices.
    fn bool_slice_assign(
        tensor: BoolTensor<B>,
        slices: &[crate::Slice],
        value: BoolTensor<B>,
    ) -> BoolTensor<B>;

    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices of the elements to select.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements.
    fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B> {
        // Default implementation: convert to int, select, then convert back to bool
        let int_tensor = B::bool_into_int(tensor);
        let selected = B::int_select(int_tensor, dim, indices);
        B::int_equal_elem(selected, 1_i32.elem())
    }

    /// Assign the selected elements along the given dimension corresponding to the given indices
    /// to the given value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to assign the values to.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices of the elements to assign.
    /// * `value` - The values to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the assigned values.
    fn bool_select_assign(
        tensor: BoolTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
        value: BoolTensor<B>,
    ) -> BoolTensor<B> {
        // Default implementation: convert to int, select_assign, then convert back to bool
        let int_tensor = B::bool_into_int(tensor);
        let int_values = B::bool_into_int(value);
        let assigned = B::int_select_assign(int_tensor, dim, indices, int_values);
        // After select_assign with sum reduction, any non-zero value should be true
        B::int_greater_elem(assigned, 0_i32.elem())
    }

    /// Repeats one dimension of the tensor a given number of times along that dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension to repeat.
    /// * `times` - The number of times to repeat the dimension.
    ///
    /// # Returns
    ///
    /// The tensor with the dimension repeated.
    fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
        repeat_with_slice_assign::<B, Bool>(tensor, dim, times)
    }

    /// Concatenates the tensors along the given dimension.
    ///
    /// # Arguments
    ///
    /// * `tensors` - The tensors to concatenate.
    /// * `dim` - The dimension to concatenate along.
    ///
    /// # Returns
    ///
    /// The tensor with the tensors concatenated along the given dimension.
    fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {
        cat_with_slice_assign::<B, Bool>(tensors, dim)
    }

    /// Equates the two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the result of the equate.
    fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;

    /// Element-wise non-equality comparison.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the result of the comparison.
    fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
        let equal_tensor = B::bool_equal(lhs, rhs);
        B::bool_not(equal_tensor)
    }

    /// Inverses boolean values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the result of the negation.
    fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;

    /// Executes the logical and (`&&`) operation on two boolean tensors.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the result of the logical and.
    fn bool_and(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;

    /// Executes the logical or (`||`) operation on two boolean tensors.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the result of the logical or.
    fn bool_or(tensor: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;

    /// Element-wise exclusive or.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the result of the comparison.
    fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {
        // xor(a, b) == (a != b) for booleans, so delegate to the
        // non-equality comparison.
        Self::bool_not_equal(lhs, rhs)
    }

    /// Transposes a bool tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to transpose.
    ///
    /// # Returns
    ///
    /// The transposed tensor (the last two dimensions swapped).
    fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {
        let ndims = tensor.shape().num_dims();
        Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)
    }

    /// Swaps two dimensions of a bool tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;

    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;

    /// Reverse the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// The tensor with the elements reversed.
    fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;

    /// Tests if any element in the boolean `tensor` evaluates to True.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test.
    ///
    /// # Returns
    ///
    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
    fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {
        // any == (sum of 0/1 values) > 0
        let sum = B::int_sum(B::bool_into_int(tensor));
        B::int_greater_elem(sum, 0.elem())
    }

    /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test.
    /// * `dim` - The axis along which to test.
    ///
    /// # Returns
    ///
    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
    /// evaluates to True, False otherwise.
    fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
        B::int_greater_elem(sum, 0.elem())
    }

    /// Tests if all elements in the boolean `tensor` evaluate to True.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test.
    ///
    /// # Returns
    ///
    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
    /// evaluate to True, False otherwise.
    fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {
        // all == (sum of 0/1 values) equals the total element count
        let num_elems = tensor.shape().num_elements();
        let sum = B::int_sum(B::bool_into_int(tensor));
        B::int_equal_elem(sum, (num_elems as i32).elem())
    }

    /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to test.
    /// * `dim` - The axis along which to test.
    ///
    /// # Returns
    ///
    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
    /// evaluates to True, False otherwise.
    fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {
        let num_elems = tensor.shape().dims[dim];
        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);
        B::int_equal_elem(sum, (num_elems as i32).elem())
    }

    /// Compute the indices of the elements that are non-zero, grouped by element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A 2D tensor containing the indices of all non-zero elements of the given tensor.
    /// Each row contains the indices of a non-zero element.
    fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {
        async {
            // Size of each output tensor is variable (= number of nonzero elements in the tensor).
            // Reading the data to count the number of truth values might cause sync but is required.
            let device = B::bool_device(&tensor);
            let data = B::bool_into_data(tensor)
                .await
                .expect("Can read the data without error");
            argwhere_data::<B>(data, &device)
        }
    }

    /// Broadcasts the bool `tensor` to the given `shape`.
    fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;

    /// Unfold windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of complete windows is `(shape[dim] - size) / step + 1` when
    /// `shape[dim] >= size`, and 0 otherwise.
    // NOTE(review): the previous doc stated `max(0, (shape[dim] - size).ceil_div(step))`,
    // which undercounts (e.g. shape[dim] == size yields 0 instead of 1 window); the
    // formula above matches standard unfold semantics — confirm against implementations.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - the selected dim.
    /// * `size` - the size of each unfolded window.
    /// * `step` - the step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;
}
| use super::cat::cat_with_slice_assign; | ||
| use super::repeat_dim::repeat_with_slice_assign; | ||
| use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; | ||
| use crate::backend::ExecutionError; | ||
| use crate::{ | ||
| Distribution, ElementConversion, Int, IntDType, TensorData, backend::Backend, tensor::Shape, | ||
| }; | ||
| use crate::{TensorMetadata, argsort, sort, sort_with_indices}; | ||
| use alloc::vec::Vec; | ||
| use core::ops::Range; | ||
| /// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor) | ||
| /// for documentation on each function. | ||
| pub trait IntTensorOps<B: Backend> { | ||
| /// Creates a new int tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `shape` - The shape of the tensor. | ||
| /// * `device` - The device to create the tensor on. | ||
| /// * `dtype` - The target data type. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The integer tensor with the given shape. | ||
| fn int_empty(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B>; | ||
| /// Converts the tensor to a data structure. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The data structure with the tensor's data. | ||
| fn int_into_data( | ||
| tensor: IntTensor<B>, | ||
| ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send; | ||
| /// Creates a tensor from the data structure. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `data` - The data structure. | ||
| /// * `device` - The device to create the tensor on. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the data. | ||
| fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>; | ||
| /// Gets the device of the tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The device of the tensor. | ||
| fn int_device(tensor: &IntTensor<B>) -> Device<B>; | ||
| /// Moves the tensor to the given device. | ||
| fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>; | ||
| /// Reshapes the tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// * `shape` - The new shape. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the new shape. | ||
| fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>; | ||
| /// Gets the element at the given indices. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// * `slices` - The slices specifying ranges and steps for each dimension. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The elements at the given indices. | ||
| /// | ||
| /// # Note | ||
| /// | ||
| /// Empty slices (where start >= end) are handled at the high-level tensor API and will not | ||
| /// be passed to this method. Backend implementations do not need to handle empty slices. | ||
| fn int_slice(tensor: IntTensor<B>, slices: &[crate::Slice]) -> IntTensor<B>; | ||
| /// Sets the values in the tensor for the given ranges. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// * `ranges` - The ranges to set the values for. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the values set for the given ranges. | ||
| fn int_slice_assign( | ||
| tensor: IntTensor<B>, | ||
| slices: &[crate::Slice], | ||
| value: IntTensor<B>, | ||
| ) -> IntTensor<B>; | ||
| /// Converts int tensor to float tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The int tensor with the same data as the float tensor. | ||
| fn int_into_float(tensor: IntTensor<B>) -> FloatTensor<B>; | ||
| /// Fills the tensor with values from the source tensor if the mask is true at the given | ||
| /// indices. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// * `mask` - The mask. | ||
| /// * `source` - The source tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the values filled. | ||
| fn int_mask_where( | ||
| tensor: IntTensor<B>, | ||
| mask: BoolTensor<B>, | ||
| source: IntTensor<B>, | ||
| ) -> IntTensor<B>; | ||
| /// Fills the tensor with the given value if the mask is true at the given indices. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// * `mask` - The mask. | ||
| /// * `value` - The value. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the values filled. | ||
| fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: IntElem<B>) -> IntTensor<B>; | ||
/// Gather elements from the tensor at the given indices.
///
/// # Arguments
///
/// * `dim` - The dimension to gather from.
/// * `tensor` - The tensor.
/// * `indices` - The indices.
///
/// # Returns
///
/// The tensor with the gathered elements.
fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;
/// Scatter a given value to the tensor at the given indices.
///
/// # Arguments
///
/// * `dim` - The dimension to scatter to.
/// * `tensor` - The tensor.
/// * `indices` - The indices.
/// * `value` - The value.
///
/// # Returns
///
/// The tensor with the values scattered.
fn int_scatter(
    dim: usize,
    tensor: IntTensor<B>,
    indices: IntTensor<B>,
    value: IntTensor<B>,
) -> IntTensor<B>;
/// Select tensor elements along the given dimension corresponding to the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to select from.
/// * `indices` - The indices.
///
/// # Returns
///
/// The tensor with the selected elements.
fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;
/// Assign the selected elements along the given dimension corresponding to the given indices
/// to the given value.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to select from.
/// * `indices` - The indices.
/// * `value` - The value.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn int_select_assign(
    tensor: IntTensor<B>,
    dim: usize,
    indices: IntTensor<B>,
    value: IntTensor<B>,
) -> IntTensor<B>;
| /// Repeats the tensor along the given dimension the given number of times. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// * `dim` - The dimension to repeat. | ||
| /// * `times` - The number of times to repeat. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the given dimension repeated the given number of times. | ||
| fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> { | ||
| repeat_with_slice_assign::<B, Int>(tensor, dim, times) | ||
| } | ||
| /// Concatenates the given tensors along the given dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensors` - The tensors. | ||
| /// * `dim` - The dimension to concatenate along. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The concatenated tensor. | ||
| fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> { | ||
| cat_with_slice_assign::<B, Int>(tensors, dim) | ||
| } | ||
| /// Element-wise equality comparison. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The boolean tensor with the result of the comparison. | ||
| fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>; | ||
| /// Element-wise non-equality comparison. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The boolean tensor with the result of the comparison. | ||
| fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> { | ||
| let equal_tensor = B::int_equal(lhs, rhs); | ||
| B::bool_not(equal_tensor) | ||
| } | ||
| /// Element-wise equality comparison with a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The boolean tensor with the result of the comparison. | ||
| fn int_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>; | ||
| /// Element-wise non-equality comparison with a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The boolean tensor with the result of the comparison. | ||
| fn int_not_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B> { | ||
| let equal_tensor = B::int_equal_elem(lhs, rhs); | ||
| B::bool_not(equal_tensor) | ||
| } | ||
/// Element-wise greater than comparison.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
/// Element-wise greater than comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_greater_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
/// Element-wise greater than or equal comparison.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
/// Element-wise greater than or equal comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
/// Element-wise less than comparison.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
/// Element-wise less than comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_lower_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
/// Element-wise less than or equal comparison.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;
/// Element-wise less than or equal comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: IntElem<B>) -> BoolTensor<B>;
// ==== NUMERIC ==== //
/// Element-wise addition.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The result of the addition.
fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
/// Element-wise addition with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The result of the addition.
fn int_add_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
| /// Element-wise power with a IntTensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side IntTensor. | ||
| /// * `rhs` - The right-hand side IntTensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The elements of `lhs` raised to the power of the elements of `rhs`. | ||
| fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> { | ||
| B::float_into_int(B::float_powi(B::int_into_float(lhs), rhs)) | ||
| } | ||
| /// Element-wise power with a floatTensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side floatTensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor. | ||
| fn int_powf(lhs: IntTensor<B>, rhs: FloatTensor<B>) -> IntTensor<B> { | ||
| B::float_into_int(B::float_powf(B::int_into_float(lhs), rhs)) | ||
| } | ||
| /// Element-wise power with a scalar. | ||
| /// | ||
| /// # Backend Implementors Note | ||
| /// | ||
| /// A number of common exponent cases can be implemented with operations | ||
| /// which are much cheaper than generic exponentiation. | ||
| /// | ||
| /// This (`Backend` impl overridable) operation handles generic optimizations | ||
| /// for several common integer exponent cases; and then dispatches to | ||
| /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`] | ||
| /// operation to handle the generic case. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The elements of `lhs` raised to the value of `rhs`. | ||
| fn int_powi_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> { | ||
| let exp = rhs.elem::<i32>(); | ||
| match exp { | ||
| 0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()), | ||
| 1 => lhs, | ||
| 2 => Self::int_mul(lhs.clone(), lhs), | ||
| _ => Self::int_powi_scalar_impl(lhs, rhs), | ||
| } | ||
| } | ||
| /// Element-wise power with a scalar. | ||
| /// | ||
| /// # Backend Implementors Note | ||
| /// | ||
| /// This is the generic implementation of integer exponentiation | ||
| /// called by [`Self::int_powi_scalar`] in the fallback case. | ||
| /// | ||
| /// By default, this performs a relatively expensive conversion to float, | ||
| /// exponentiation in float, and conversion back to int. | ||
| /// This reduces the minimal operation set for `Backend`s, | ||
| /// at the cost of performance. | ||
| /// | ||
| /// This is a good target for specialized optimizations in `Backend` implementations. | ||
| /// | ||
| /// As a general rule, this should not be called directly. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The elements of `lhs` raised to the value of `rhs`. | ||
| fn int_powi_scalar_impl(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B> { | ||
| B::float_into_int(B::float_powi_scalar_impl(B::int_into_float(lhs), rhs)) | ||
| } | ||
| /// Element-wise power with a floatTensor. | ||
| /// | ||
| /// Handles a number of special cases, then calls [`Self::int_powf_scalar_impl`]. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor. | ||
| fn int_powf_scalar(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> { | ||
| if num_traits::Float::floor(rhs) == rhs { | ||
| let exp = B::IntElem::from_elem(rhs as i32); | ||
| Self::int_powi_scalar(lhs, exp) | ||
| } else { | ||
| Self::int_powf_scalar_impl(lhs, rhs) | ||
| } | ||
| } | ||
| /// Element-wise power with a floatTensor. | ||
| /// | ||
| /// Fallback handler for [`Self::int_powf_scalar`]. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor. | ||
| fn int_powf_scalar_impl(lhs: IntTensor<B>, rhs: f32) -> IntTensor<B> { | ||
| B::float_into_int(B::float_powf_scalar_impl(B::int_into_float(lhs), rhs)) | ||
| } | ||
| /// Clamps a tensor under a minimum value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to clamp. | ||
| /// * `min` - The minimum value. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The clamped tensor. | ||
| fn int_clamp_min(tensor: IntTensor<B>, min: IntElem<B>) -> IntTensor<B> { | ||
| let mask = Self::int_lower_elem(tensor.clone(), min); | ||
| Self::int_mask_fill(tensor, mask, min) | ||
| } | ||
| /// Clamps a tensor over a maximum value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to clamp. | ||
| /// * `max` - The maximum value. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The clamped tensor. | ||
| fn int_clamp_max(tensor: IntTensor<B>, max: IntElem<B>) -> IntTensor<B> { | ||
| let mask = Self::int_greater_elem(tensor.clone(), max); | ||
| Self::int_mask_fill(tensor, mask, max) | ||
| } | ||
| /// Clamps a tensor between a minimum and maximum value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to clamp. | ||
| /// * `min` - The minimum value. | ||
| /// * `max` - The maximum value. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The clamped tensor. | ||
| fn int_clamp(tensor: IntTensor<B>, min: IntElem<B>, max: IntElem<B>) -> IntTensor<B> { | ||
| Self::int_clamp_min(Self::int_clamp_max(tensor, max), min) | ||
| } | ||
/// Element-wise subtraction.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The result of the subtraction.
fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
/// Element-wise subtraction with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The result of the subtraction.
fn int_sub_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
/// Element-wise multiplication.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The result of the multiplication.
fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
/// Element-wise multiplication with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The result of the multiplication.
fn int_mul_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
/// Element-wise division.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The result of the division.
fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
/// Element-wise division with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The result of the division.
fn int_div_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
/// Element-wise modulus.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The result of the element-wise modulus.
fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
/// Element-wise modulus with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The result of applying the modulus of the scalar to the tensor.
fn int_remainder_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
/// Multiplies two tensors together using matrix multiplication.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side tensor.
///
/// # Returns
///
/// The result of multiplying the two tensors together using matrix multiplication.
fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
| /// Element-wise negation. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to negate. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The negated tensor. | ||
| fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> { | ||
| Self::int_mul_scalar(tensor, (-1.0).elem::<IntElem<B>>()) | ||
| } | ||
| /// Creates a tensor of zeros. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `shape` - The shape of the tensor. | ||
| /// * `device` - The device to create the tensor on. | ||
| /// * `dtype` - The target data type. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor of zeros. | ||
| fn int_zeros(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> { | ||
| Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device) | ||
| } | ||
| /// Creates a tensor of ones. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `shape` - The shape of the tensor. | ||
| /// * `device` - The device to create the tensor on. | ||
| /// * `dtype` - The target data type. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor of ones. | ||
| fn int_ones(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> { | ||
| Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device) | ||
| } | ||
| /// Creates a tensor filled with given value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `shape` - The shape of the tensor. | ||
| /// * `fill_value` - The value with which to fill the tensor. | ||
| /// * `device` - The device to create the tensor on. | ||
| /// * `dtype` - The target data type. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor filled with given value | ||
| fn int_full( | ||
| shape: Shape, | ||
| fill_value: IntElem<B>, | ||
| device: &Device<B>, | ||
| dtype: IntDType, | ||
| ) -> IntTensor<B> { | ||
| Self::int_from_data( | ||
| TensorData::full_dtype(shape, fill_value, dtype.into()), | ||
| device, | ||
| ) | ||
| } | ||
/// Sums all elements in the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
///
/// # Returns
///
/// The sum of all elements in the tensor.
fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;
/// Sums all elements in the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
/// * `dim` - The dimension to sum along.
///
/// # Returns
///
/// The sum of all elements in the tensor along the dimension.
fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
/// Computes the product of all elements in the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
///
/// # Returns
///
/// The product of all elements in the tensor.
fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;
/// Computes the product of all elements in the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the product of.
/// * `dim` - The dimension to compute the product along.
///
/// # Returns
///
/// The product of all elements in the tensor along the dimension.
fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
| /// Computes the mean of all elements in the tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to compute the mean of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The mean of all elements in the tensor. | ||
| fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> { | ||
| let num_elems = tensor.shape().num_elements(); | ||
| B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem()) | ||
| } | ||
/// Computes the mean of all elements in the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the mean of.
/// * `dim` - The dimension to compute the mean along.
///
/// # Returns
///
/// The mean of all elements in the tensor along the dimension.
fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
/// Computes the cumulative sum of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative sum of.
/// * `dim` - The dimension along which to compute the cumulative sum.
///
/// # Returns
///
/// A tensor with the same shape where each element is the cumulative sum
/// of all elements up to and including that position along the dimension.
fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
/// Computes the cumulative product of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative product of.
/// * `dim` - The dimension along which to compute the cumulative product.
///
/// # Returns
///
/// A tensor with the same shape where each element is the cumulative product
/// of all elements up to and including that position along the dimension.
fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
/// Computes the cumulative minimum of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative minimum of.
/// * `dim` - The dimension along which to compute the cumulative minimum.
///
/// # Returns
///
/// A tensor with the same shape where each element is the minimum
/// of all elements up to and including that position along the dimension.
fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
/// Computes the cumulative maximum of elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the cumulative maximum of.
/// * `dim` - The dimension along which to compute the cumulative maximum.
///
/// # Returns
///
/// A tensor with the same shape where each element is the maximum
/// of all elements up to and including that position along the dimension.
fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
/// Gets the indices of the maximum elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum indices of.
/// * `dim` - The dimension to get the maximum indices along.
///
/// # Returns
///
/// The indices of the maximum elements along the dimension.
fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
/// Gets the indices of the minimum elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum indices of.
/// * `dim` - The dimension to get the minimum indices along.
///
/// # Returns
///
/// The indices of the minimum elements along the dimension.
fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;
| /// Gets the maximum element in the tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum element of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The maximum element in the tensor. | ||
| fn int_max(tensor: IntTensor<B>) -> IntTensor<B> { | ||
| let shape = tensor.shape(); | ||
| let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); | ||
| B::int_max_dim(tensor, 0) | ||
| } | ||
| /// Gets the maximum element in the tensor along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum element of. | ||
| /// * `dim` - The dimension to get the maximum element along. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The maximum element in the tensor along the dimension. | ||
| fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> { | ||
| let index = B::int_argmax(tensor.clone(), dim); | ||
| B::int_gather(dim, tensor, index) | ||
| } | ||
| /// Gets the maximum elements and corresponding indices along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum elements and indices of. | ||
| /// * `dim` - The dimension to get the maximum elements and indices along. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The maximum elements and corresponding indices along the dimension. | ||
| fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) { | ||
| let index = B::int_argmax(tensor.clone(), dim); | ||
| let values = B::int_gather(dim, tensor, index.clone()); | ||
| (values, index) | ||
| } | ||
| /// Gets the maximum absolute element in the tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum element of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The maximum element in the tensor. | ||
| fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> { | ||
| let shape = tensor.shape(); | ||
| let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); | ||
| B::int_max_abs_dim(tensor, 0) | ||
| } | ||
| /// Gets the maximum absolute element in the tensor along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum element of. | ||
| /// * `dim` - The dimension to get the maximum element along. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The maximum element in the tensor along the dimension. | ||
| fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> { | ||
| B::int_max_dim(B::int_abs(tensor), dim) | ||
| } | ||
| /// Gets the minimum element in the tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the minimum element of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The minimum element in the tensor. | ||
| fn int_min(tensor: IntTensor<B>) -> IntTensor<B> { | ||
| let shape = tensor.shape(); | ||
| let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()])); | ||
| B::int_min_dim(tensor, 0) | ||
| } | ||
| /// Gets the minimum elements in the tensor along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the minimum element of. | ||
| /// * `dim` - The dimension to get the minimum element along. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The minimum element in the tensor along the dimension. | ||
| fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> { | ||
| let index = B::int_argmin(tensor.clone(), dim); | ||
| B::int_gather(dim, tensor, index) | ||
| } | ||
| /// Gets the minimum elements and corresponding indices along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the minimum elements and indices of. | ||
| /// * `dim` - The dimension to get the minimum elements and indices along. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The minimum elements and corresponding indices along the dimension. | ||
| fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) { | ||
| let indices = B::int_argmin(tensor.clone(), dim); | ||
| let values = B::int_gather(dim, tensor, indices.clone()); | ||
| (values, indices) | ||
| } | ||
| /// Returns a new tensor with absolute values. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to take absolute value of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with absolute values. | ||
| fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>; | ||
| /// Transposes an int tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to transpose. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The transposed tensor. | ||
| fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> { | ||
| let ndims = tensor.shape().num_dims(); | ||
| Self::int_swap_dims(tensor, ndims - 2, ndims - 1) | ||
| } | ||
/// Swaps two dimensions of an int tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to swap the dimensions of.
/// * `dim1` - The first dimension to swap.
/// * `dim2` - The second dimension to swap.
///
/// # Returns
///
/// The tensor with the dimensions swapped.
fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;
/// Permutes the dimensions of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to permute the dimensions of.
/// * `axes` - The new order of the dimensions.
///
/// # Returns
///
/// The tensor with the dimensions permuted.
fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
/// Reverse the order of elements in a tensor along the given axes.
///
/// # Arguments
///
/// * `tensor` - The tensor to reverse.
/// * `axes` - The axes to reverse.
///
/// # Returns
///
/// The tensor with the elements reversed.
fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;
/// Creates a new int tensor with random values.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `distribution` - The distribution to sample from.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given shape and random values.
fn int_random(shape: Shape, distribution: Distribution, device: &Device<B>) -> IntTensor<B>;
| /// Creates a new tensor with values from the given range with the given step size. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `range` - The range of values. | ||
| /// * `step` - The step size. | ||
| /// * `device` - The device to create the tensor on. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the given values. | ||
| fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B> { | ||
| let value = range | ||
| .step_by(step) | ||
| .map(|i| i.elem()) | ||
| .collect::<Vec<IntElem<B>>>(); | ||
| let shape = Shape::new([value.len()]); | ||
| let data = TensorData::new(value, shape); | ||
| B::int_from_data(data, device) | ||
| } | ||
| /// Creates a new tensor with values from the given range. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `range` - The range of values. | ||
| /// * `device` - The device to create the tensor on. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the given values. | ||
| /// | ||
| /// # Remarks | ||
| /// | ||
| /// Uses `arange_step` with a step size of 1 under the hood. | ||
| fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B> { | ||
| Self::int_arange_step(range, 1, device) | ||
| } | ||
| /// Tests if any element in the int `tensor` evaluates to True. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise. | ||
| fn int_any(tensor: IntTensor<B>) -> BoolTensor<B> { | ||
| let bool_tensor = B::int_equal_elem(tensor, 0.elem()); | ||
| let bool_tensor = B::bool_not(bool_tensor); | ||
| let sum = B::int_sum(B::bool_into_int(bool_tensor)); | ||
| B::int_greater_elem(sum, 0.elem()) | ||
| } | ||
| /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// * `dim` - The axis along which to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis | ||
| /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input | ||
| /// evaluates to True, False otherwise. | ||
| fn int_any_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> { | ||
| let bool_tensor = B::int_equal_elem(tensor, 0.elem()); | ||
| let bool_tensor = B::bool_not(bool_tensor); | ||
| let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim); | ||
| B::int_greater_elem(sum, 0.elem()) | ||
| } | ||
| /// Tests if all elements in the int `tensor` evaluate to True. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor | ||
| /// evaluate to True, False otherwise. | ||
| fn int_all(tensor: IntTensor<B>) -> BoolTensor<B> { | ||
| let num_elems = tensor.shape().num_elements(); | ||
| let bool_tensor = B::int_equal_elem(tensor, 0.elem()); | ||
| let bool_tensor = B::bool_not(bool_tensor); | ||
| let sum = B::int_sum(B::bool_into_int(bool_tensor)); | ||
| B::int_equal_elem(sum, (num_elems as i32).elem()) | ||
| } | ||
| /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// * `dim` - The axis along which to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis | ||
| /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input | ||
| /// evaluates to True, False otherwise. | ||
| fn int_all_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> { | ||
| let num_elems = tensor.shape().dims[dim]; | ||
| let bool_tensor = B::int_equal_elem(tensor, 0.elem()); | ||
| let bool_tensor = B::bool_not(bool_tensor); | ||
| let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim); | ||
| B::int_equal_elem(sum, (num_elems as i32).elem()) | ||
| } | ||
| /// Returns the signs of the int `tensor`. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to extract the signs from. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`. | ||
| fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> { | ||
| let dtype = tensor.dtype(); | ||
| let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor), dtype.into()); | ||
| let less_than_zero = B::int_lower_elem(tensor.clone(), 0.0f32.elem()); | ||
| let greater_than_zero = B::int_greater_elem(tensor, 0.0f32.elem()); | ||
| let mut result = B::int_mask_fill(zeros, less_than_zero, (-1.0f32).elem()); | ||
| result = B::int_mask_fill(result, greater_than_zero, 1.0f32.elem()); | ||
| result | ||
| } | ||
/// Broadcasts the int `tensor` to the given `shape`.
///
/// # Arguments
///
/// * `tensor` - The tensor to broadcast.
/// * `shape` - The target shape.
///
/// # Returns
///
/// The broadcasted tensor.
fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;
/// Sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the elements are sorted by value.
fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
    sort::<B, Int>(tensor, dim, descending)
}
/// Sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor and corresponding indices, where
/// the elements are sorted by value and the indices map back to the original input tensor.
fn int_sort_with_indices(
    tensor: IntTensor<B>,
    dim: usize,
    descending: bool,
) -> (IntTensor<B>, IntTensor<B>) {
    sort_with_indices::<B, Int>(tensor, dim, descending)
}
/// Returns the indices that sort the elements of the input `tensor` by value
/// along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the indices map back to the original input tensor.
fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
    argsort::<B, Int>(tensor, dim, descending)
}
// NOTE(review): the binary variants presumably operate elementwise on the two
// tensors, and the `_scalar` variants apply `rhs` to every element; whether the
// right shift is arithmetic or logical for signed values is backend-defined —
// confirm against the backend implementations.
/// Bitwise AND operation for Int Tensors
fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
/// Bitwise AND operation for Int Tensors with a scalar
fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
/// Bitwise OR operation for Int Tensors
fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
/// Bitwise OR operation for Int Tensors with a scalar
fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
/// Bitwise XOR operation for Int Tensors
fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
/// Bitwise XOR operation for Int Tensors with a scalar
fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
/// Bitwise NOT operation for Int Tensors
fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;
/// Bitwise left shift operation for Int Tensors
fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
/// Bitwise left shift operation for Int Tensors with a scalar
fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
/// Bitwise right shift operation for Int Tensors
fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
/// Bitwise right shift operation for Int Tensors with a scalar
fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
/// Converts a tensor to another integer data type.
///
/// # Arguments
///
/// * `tensor` - The tensor to convert.
/// * `dtype` - The target data type.
///
/// # Returns
///
/// A tensor with the same values as `tensor` but stored in the target integer data type.
fn int_cast(tensor: IntTensor<B>, dtype: IntDType) -> IntTensor<B>;
/// Unfold windows along a dimension.
///
/// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
/// where windows are advanced by `step` at each index.
///
/// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
///
/// # Arguments
///
/// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
/// * `dim` - the selected dim.
/// * `size` - the size of each unfolded window.
/// * `step` - the step between each window.
///
/// # Returns
///
/// A tensor view with shape ``[pre=..., windows, size, post=...]``.
// NOTE(review): two points in the docs above look suspect — confirm against the
// backend implementations: (1) the window-count formula disagrees with the
// conventional `(shape[dim] - size) / step + 1` for `shape[dim] >= size`
// (e.g. len=5, size=3, step=1 has 3 windows, the formula gives 2); (2) the
// `unfold4d` default implementation permutes the result with `[0, 1, 4, 5, 2, 3]`,
// which implies the `size` axis is appended *last*
// (``[pre=..., windows, post=..., size]``), not inserted right after `windows`.
fn int_unfold(tensor: IntTensor<B>, dim: usize, size: usize, step: usize) -> IntTensor<B>;
| } |
| mod activation; | ||
| mod alias; | ||
| mod bool_tensor; | ||
| mod int_tensor; | ||
| mod modules; | ||
| mod qtensor; | ||
| mod tensor; | ||
| mod transaction; | ||
| pub use activation::*; | ||
| pub use alias::*; | ||
| pub use bool_tensor::*; | ||
| pub use int_tensor::*; | ||
| pub use modules::*; | ||
| pub use qtensor::*; | ||
| pub use tensor::*; | ||
| pub use transaction::*; |
| use super::{conv, pool}; | ||
| use crate::ops::unfold::unfold4d_using_conv2d; | ||
| use crate::{ | ||
| Shape, TensorMetadata, | ||
| backend::Backend, | ||
| ops::{FloatTensor, IntTensor}, | ||
| }; | ||
| use core::num::NonZeroUsize; | ||
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
// `#[derive(new)]` (derive-new crate) generates `Conv2dBackward::new(x_grad, weights_grad, bias_grad)`.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,
    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
// `#[derive(new)]` (derive-new crate) generates a field-order constructor.
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
    /// Gradient with respect to the offsets.
    pub offset_grad: FloatTensor<B>,
    /// Gradient with respect to the weights.
    pub weight_grad: FloatTensor<B>,
    /// Gradient with respect to the mask, when a mask was used.
    pub mask_grad: Option<FloatTensor<B>>,
    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
// `#[derive(new)]` (derive-new crate) generates a field-order constructor.
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,
    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
/// Results from [max_pool1d_with_indices](ModuleOps::max_pool1d_with_indices).
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,
    /// The indices of the selected (maximum) elements.
    pub indices: IntTensor<B>,
}
/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
/// Results from [max_pool2d_with_indices](ModuleOps::max_pool2d_with_indices).
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,
    /// The indices of the selected (maximum) elements.
    pub indices: IntTensor<B>,
}
/// Check that the parameter value is non-zero.
///
/// Returns `value` unchanged.
///
/// # Panics
///
/// Panics with `msg` when `value` is zero.
// NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    if NonZeroUsize::new(value).is_none() {
        panic!("{msg}");
    }
    value
}
/// Convolution options.
///
/// The non-zero invariants below are enforced by [`ConvOptions::new`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],
    /// Padding.
    pub padding: [usize; N],
    /// Dilation (non-zero).
    pub dilation: [usize; N],
    /// Groups (non-zero).
    pub groups: usize,
}
| impl<const N: usize> ConvOptions<N> { | ||
| /// Constructs a new `ConvOptions`. | ||
| pub fn new( | ||
| stride: [usize; N], | ||
| padding: [usize; N], | ||
| dilation: [usize; N], | ||
| groups: usize, | ||
| ) -> Self { | ||
| Self { | ||
| stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")), | ||
| padding, | ||
| dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")), | ||
| groups: check_nonzero(groups, "groups must be non-zero"), | ||
| } | ||
| } | ||
| } | ||
/// Deformable convolution options.
///
/// The non-zero invariants below are enforced by [`DeformConvOptions::new`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],
    /// Padding.
    pub padding: [usize; N],
    /// Dilation (non-zero).
    pub dilation: [usize; N],
    /// Weight Groups (non-zero).
    pub weight_groups: usize,
    /// Offset Groups (non-zero).
    pub offset_groups: usize,
}
| impl<const N: usize> DeformConvOptions<N> { | ||
| /// Constructs a new `DeformConvOptions`. | ||
| pub fn new( | ||
| stride: [usize; N], | ||
| padding: [usize; N], | ||
| dilation: [usize; N], | ||
| weight_groups: usize, | ||
| offset_groups: usize, | ||
| ) -> Self { | ||
| Self { | ||
| stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")), | ||
| padding, | ||
| dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")), | ||
| weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"), | ||
| offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"), | ||
| } | ||
| } | ||
| } | ||
/// Transposed convolution options.
///
/// The non-zero invariants below are enforced by [`ConvTransposeOptions::new`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],
    /// Padding.
    pub padding: [usize; N],
    /// Additional padding applied to the output.
    pub padding_out: [usize; N],
    /// Dilation (non-zero).
    pub dilation: [usize; N],
    /// Groups (non-zero).
    pub groups: usize,
}
| impl<const N: usize> ConvTransposeOptions<N> { | ||
| /// Constructs a new `ConvTransposeOptions`. | ||
| pub fn new( | ||
| stride: [usize; N], | ||
| padding: [usize; N], | ||
| padding_out: [usize; N], | ||
| dilation: [usize; N], | ||
| groups: usize, | ||
| ) -> Self { | ||
| Self { | ||
| stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")), | ||
| padding, | ||
| padding_out, | ||
| dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")), | ||
| groups: check_nonzero(groups, "groups must be non-zero"), | ||
| } | ||
| } | ||
| } | ||
/// Unfold operation options.
///
/// The non-zero invariants below are enforced by [`UnfoldOptions::new`].
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// The number of positions to slide over the input tensor in each dimension.
    /// A stride of `[1, 1]` will slide the kernel one pixel at a time.
    pub stride: [usize; 2],
    /// The number of zero-padding pixels added to each side of the input tensor in each dimension.
    pub padding: [usize; 2],
    /// The spacing between the blocks (patches) in the original input tensor.
    pub dilation: [usize; 2],
}
| impl UnfoldOptions { | ||
| /// Constructs a new `UnfoldOptions`. | ||
| pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self { | ||
| Self { | ||
| stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")), | ||
| padding, | ||
| dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")), | ||
| } | ||
| } | ||
| } | ||
/// Algorithm used for upsampling.
// NOTE(review): `#[derive(new)]` on a fieldless enum (derive-new crate) generates
// per-variant `new_*` constructors, which seems unnecessary for unit variants —
// confirm whether anything uses them before removing.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    /// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
    Nearest,
    /// Bilinear interpolation.
    /// <https://en.wikipedia.org/wiki/Bilinear_interpolation>
    Bilinear,
    /// Bicubic interpolation.
    /// <https://en.wikipedia.org/wiki/Bicubic_interpolation>
    Bicubic,
}
/// Interpolation options.
// `#[derive(new)]` (derive-new crate) generates `InterpolateOptions::new(mode)`.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// Algorithm used for upsampling.
    pub mode: InterpolateMode,
}
/// Padding mode for grid sampling when coordinates are out of bounds.
///
/// Matches PyTorch's `padding_mode` parameter in `grid_sample`.
/// Defaults to [`GridSamplePaddingMode::Zeros`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Fill with zeros for out-of-bounds coordinates.
    #[default]
    Zeros,
    /// Clamp coordinates to the border (use nearest edge value).
    Border,
    /// Reflect coordinates at the boundary.
    Reflection,
}
/// Options for grid sampling operations.
///
/// The `Default` impl uses bilinear interpolation, zero padding, and
/// `align_corners = false`.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation mode (bilinear, nearest, or bicubic).
    pub mode: InterpolateMode,
    /// Padding mode for out-of-bounds coordinates.
    pub padding_mode: GridSamplePaddingMode,
    /// If `true`, grid values of -1 and 1 correspond to the corner pixels.
    /// If `false`, they correspond to the corner points of the corner pixels
    /// (i.e., -1 maps to -0.5 and 1 maps to size - 0.5 in pixel coordinates).
    pub align_corners: bool,
}
| impl Default for GridSampleOptions { | ||
| fn default() -> Self { | ||
| Self { | ||
| mode: InterpolateMode::Bilinear, | ||
| padding_mode: GridSamplePaddingMode::Zeros, | ||
| align_corners: false, | ||
| } | ||
| } | ||
| } | ||
| impl GridSampleOptions { | ||
| /// Create new grid sample options with the given interpolation mode. | ||
| /// | ||
| /// Uses default values for padding_mode (Zeros) and align_corners (false). | ||
| pub fn new(mode: InterpolateMode) -> Self { | ||
| Self { | ||
| mode, | ||
| ..Default::default() | ||
| } | ||
| } | ||
| /// Set the padding mode. | ||
| pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self { | ||
| self.padding_mode = padding_mode; | ||
| self | ||
| } | ||
| /// Set align_corners. | ||
| pub fn with_align_corners(mut self, align_corners: bool) -> Self { | ||
| self.align_corners = align_corners; | ||
| self | ||
| } | ||
| } | ||
/// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate).
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
| /// Module operations trait. | ||
| pub trait ModuleOps<B: Backend> { | ||
| /// Embedding operation. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `weights` - The embedding weights. | ||
| /// * `indices` - The indices tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The output tensor. | ||
| fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> { | ||
| let [batch_size, seq_length] = indices.shape().dims(); | ||
| let [_, d_model] = weights.shape().dims(); | ||
| let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); | ||
| let output = B::float_select(weights, 0, indices); | ||
| B::float_reshape(output, Shape::new([batch_size, seq_length, d_model])) | ||
| } | ||
| /// Embedding backward operation. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `weights` - The embedding weights. | ||
| /// * `output_grad` - The output gradient. | ||
| /// * `indices` - The indices tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The gradient. | ||
| fn embedding_backward( | ||
| weights: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| indices: IntTensor<B>, | ||
| ) -> FloatTensor<B> { | ||
| let [batch_size, seq_length] = indices.shape().dims(); | ||
| let [n_embeddings, d_model] = weights.shape().dims(); | ||
| let device = B::float_device(&weights); | ||
| let dtype = output_grad.dtype(); | ||
| let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length])); | ||
| let output_grad = | ||
| B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model])); | ||
| let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into()); | ||
| B::float_select_assign(grad, 0, indices, output_grad) | ||
| } | ||
/// One dimensional convolution.
///
/// # Shapes
///
/// x: `[batch_size, channels_in, length]`,
/// weight: `[channels_out, channels_in, kernel_size]`,
/// bias: `[channels_out]`,
// The default implementation lowers the 1d convolution to the (required) 2d one.
fn conv1d(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
}
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
fn conv1d_x_backward(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
fn conv1d_weight_backward(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
fn conv1d_bias_backward(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    conv::conv1d_bias_backward::<B>(x, bias, output_grad)
}
/// Two dimensional convolution.
///
/// # Shapes
///
/// x: `[batch_size, channels_in, height, width]`,
/// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
/// bias: `[channels_out]`,
// Required method: backends must provide the 2d forward convolution; the
// backward passes below have generic default implementations.
fn conv2d(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
    options: ConvOptions<2>,
) -> FloatTensor<B>;
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
fn conv2d_x_backward(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<2>,
) -> FloatTensor<B> {
    conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
fn conv2d_weight_backward(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<2>,
) -> FloatTensor<B> {
    conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
fn conv2d_bias_backward(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
}
/// Two dimensional deformable convolution.
///
/// # Shapes
///
/// x: `[batch_size, channels_in, height, width]`,
/// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
/// bias: `[channels_out]`,
// Required method; both the forward and backward passes must be provided by the backend.
fn deform_conv2d(
    x: FloatTensor<B>,
    offset: FloatTensor<B>,
    weight: FloatTensor<B>,
    mask: Option<FloatTensor<B>>,
    bias: Option<FloatTensor<B>>,
    options: DeformConvOptions<2>,
) -> FloatTensor<B>;
/// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
///
/// Returns the gradients for every input tensor bundled in [`DeformConv2dBackward`].
fn deform_conv2d_backward(
    x: FloatTensor<B>,
    offset: FloatTensor<B>,
    weight: FloatTensor<B>,
    mask: Option<FloatTensor<B>>,
    bias: Option<FloatTensor<B>>,
    output_grad: FloatTensor<B>,
    options: DeformConvOptions<2>,
) -> DeformConv2dBackward<B>;
/// Three dimensional convolution.
///
/// # Shapes
///
/// x: `[batch_size, channels_in, depth, height, width]`,
/// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2, kernel_size_3]`,
/// bias: `[channels_out]`,
// Required method: backends must provide the 3d forward convolution; the
// backward passes below have generic default implementations.
fn conv3d(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
    options: ConvOptions<3>,
) -> FloatTensor<B>;
/// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
fn conv3d_x_backward(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<3>,
) -> FloatTensor<B> {
    conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
fn conv3d_weight_backward(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<3>,
) -> FloatTensor<B> {
    conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
fn conv3d_bias_backward(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
}
/// One dimensional transposed convolution.
///
/// # Shapes
///
/// x: `[batch_size, channels_in, length]`,
/// weight: `[channels_in, channels_out, kernel_size]`,
/// bias: `[channels_out]`,
// The default implementation lowers the 1d transposed convolution to the (required) 2d one.
fn conv_transpose1d(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
}
/// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
fn conv_transpose1d_x_backward(
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
}
/// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
fn conv_transpose1d_weight_backward(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
fn conv_transpose1d_bias_backward(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
}
/// Two dimensional transposed convolution.
///
/// # Shapes
///
/// x: `[batch_size, channels_in, height, width]`,
/// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,
/// bias: `[channels_out]`,
// Required method: backends must provide the 2d forward transposed convolution;
// the backward passes below have generic default implementations.
fn conv_transpose2d(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B>;
/// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
// Note: unlike the forward pass, the input gradient only needs the weights.
fn conv_transpose2d_x_backward(
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
    conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
}
/// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
fn conv_transpose2d_weight_backward(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
    conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
fn conv_transpose2d_bias_backward(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
}
/// Three dimensional transposed convolution.
///
/// # Shapes
///
/// x: `[batch_size, channels_in, depth, height, width]`,
/// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2, kernel_size_3]`,
/// bias: `[channels_out]`,
// Required method: backends must provide the 3d forward transposed convolution;
// the backward passes below have generic default implementations.
fn conv_transpose3d(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B>;
/// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
fn conv_transpose3d_x_backward(
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
    conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
}
/// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
fn conv_transpose3d_weight_backward(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
    conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
}
/// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
fn conv_transpose3d_bias_backward(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
}
| /// Four-dimensional unfolding. | ||
| /// | ||
| /// # Shapes | ||
| /// | ||
| /// * x: ``[batch_size, channels_in, height, width]``, | ||
| /// * returns: ``[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]``, | ||
| fn unfold4d( | ||
| x: FloatTensor<B>, | ||
| kernel_size: [usize; 2], | ||
| options: UnfoldOptions, | ||
| ) -> FloatTensor<B> { | ||
| if options.padding == [0, 0] && options.dilation == [1, 1] { | ||
| let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]); | ||
| let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]); | ||
| // batch, channels, h_blocks, w_blocks, h_kern, w_kern | ||
| let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]); | ||
| let shape = &blocks.shape().dims; | ||
| // batch, channels, h_kern, w_kern, h_blocks, w_blocks | ||
| B::float_reshape( | ||
| blocks, | ||
| [ | ||
| shape[0], | ||
| shape[1] * shape[2] * shape[3], | ||
| shape[4] * shape[5], | ||
| ] | ||
| .into(), | ||
| ) | ||
| } else { | ||
| unfold4d_using_conv2d::<B>(x, kernel_size, options) | ||
| } | ||
| } | ||
/// One dimensional avg pooling.
///
/// # Shapes
///
/// x: [batch_size, channels, length],
// The default implementation lowers the 1d pooling to the (required) 2d one.
fn avg_pool1d(
    x: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
) -> FloatTensor<B> {
    pool::avg_pool1d_from_2d::<B>(x, kernel_size, stride, padding, count_include_pad)
}
    /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation.
    ///
    /// Default implementation: lowers to the 2d backward kernel via
    /// `pool::avg_pool1d_backward_from_2d`, forwarding all parameters.
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
        )
    }
    /// Two dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    ///
    /// Required method: every backend must provide its own implementation
    /// (the 1d variants lower onto this one).
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;
    /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation.
    ///
    /// Required method: takes the forward input `x` and the incoming gradient
    /// `grad`, along with the same pooling parameters as the forward pass.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;
    /// Two dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    ///
    /// Required method: `output_size` is the target `[height, width]`.
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.
    ///
    /// Required method: `x` is the forward input, `grad` the incoming gradient.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// One dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    ///
    /// Default implementation: lowers to `adaptive_avg_pool2d` via
    /// `pool::adaptive_avg_pool1d_from_2d`.
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation.
    ///
    /// Default implementation: lowers to the 2d backward kernel.
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// One dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    ///
    /// Default implementation: lowers to the 2d max pooling via
    /// `pool::max_pool1d_from_2d`.
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }
    /// One dimensional max pooling with indices.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    ///
    /// Default implementation: lowers to the 2d variant via
    /// `pool::max_pool1d_with_indices_from_2d`, returning both the pooled
    /// values and the argmax indices needed by the backward pass.
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }
    /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
    ///
    /// Default implementation: lowers to the 2d backward kernel, routing the
    /// gradient back through the argmax `indices` captured in the forward pass.
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            output_grad,
            indices,
        )
    }
    /// Two dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    ///
    /// Required method: every backend must provide its own implementation.
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> FloatTensor<B>;
    /// Two dimensional max pooling with indices.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    ///
    /// Required method: returns the pooled values together with the argmax
    /// indices consumed by [`max_pool2d_with_indices_backward`](ModuleOps::max_pool2d_with_indices_backward).
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
    ///
    /// Required method: `indices` are the argmax positions from the forward
    /// pass; `output_grad` is the incoming gradient.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;
    /// Down/up samples the input.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    ///
    /// Required method: resamples `x` to `output_size` (`[height, width]`);
    /// the resampling mode is carried by `options`.
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
    /// Backward pass for the [interpolate](ModuleOps::interpolate) operation.
    ///
    /// Required method: `x` is the forward input, `grad` the incoming
    /// gradient at `output_size`; `options` must match the forward call.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
| } | ||
#[cfg(test)]
mod tests {
    use super::*;

    // Every option constructor validates its arguments: zero-valued strides,
    // dilations, or group counts are rejected with a panic, since they would
    // produce degenerate computations.

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _ = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _ = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _ = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _ = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _ = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _ = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _ = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _ = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _ = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}
| use crate::{BasicOps, Slice, TensorKind, TensorMetadata, backend::Backend}; | ||
| use alloc::vec::Vec; | ||
| pub(crate) fn cat_with_slice_assign<B: Backend, K: TensorKind<B> + BasicOps<B>>( | ||
| tensors: Vec<K::Primitive>, | ||
| dim: usize, | ||
| ) -> K::Primitive { | ||
| let first_tensor = tensors.first().expect("Tensors should not be empty"); | ||
| let mut shape = first_tensor.shape(); | ||
| let device = K::device(first_tensor); | ||
| let dtype = first_tensor.dtype(); | ||
| let output_dim_length: usize = tensors.iter().map(|tensor| tensor.shape().dims[dim]).sum(); | ||
| shape[dim] = output_dim_length; | ||
| let mut tensor_output = K::empty(shape.clone(), &device, dtype); | ||
| let indices_select_all = shape.iter().map(|d| 0..*d).collect::<Vec<_>>(); | ||
| let mut output_index = 0; | ||
| for tensor in tensors { | ||
| let mut indices = indices_select_all.clone(); | ||
| let tensor_dim_length = tensor.shape().dims[dim]; | ||
| indices[dim] = output_index..output_index + tensor_dim_length; | ||
| output_index += tensor_dim_length; | ||
| // Convert ranges to Slice | ||
| let slices: Vec<Slice> = indices | ||
| .iter() | ||
| .map(|r| Slice::new(r.start as isize, Some(r.end as isize), 1)) | ||
| .collect(); | ||
| tensor_output = K::slice_assign(tensor_output, &slices, tensor); | ||
| } | ||
| tensor_output | ||
| } |
| #![allow(clippy::single_range_in_vec_init)] | ||
| use super::{ConvOptions, ConvTransposeOptions}; | ||
| use crate::{Shape, ShapeError, Slice, TensorMetadata, backend::Backend, ops::FloatTensor}; | ||
| use alloc::{vec, vec::Vec}; | ||
| #[cfg(not(feature = "std"))] | ||
| #[allow(unused_imports)] | ||
| use num_traits::Float as _; | ||
| /// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a pooling operation. | ||
| pub fn calculate_pool_output_shape<const N: usize>( | ||
| in_shape: &Shape, | ||
| kernel_size: &[usize; N], | ||
| stride: &[usize; N], | ||
| padding: &[usize; N], | ||
| dilation: &[usize; N], | ||
| ) -> Result<Shape, ShapeError> { | ||
| if in_shape.rank() != N + 2 { | ||
| return Err(ShapeError::RankMismatch { | ||
| left: in_shape.rank(), | ||
| right: N + 2, | ||
| }); | ||
| } | ||
| let mut out_shape = in_shape.clone(); | ||
| // Spatial dims | ||
| for (i, size_i) in out_shape[2..].iter_mut().enumerate() { | ||
| *size_i = | ||
| calculate_pool_output_size(kernel_size[i], stride[i], padding[i], dilation[i], *size_i); | ||
| } | ||
| Ok(out_shape) | ||
| } | ||
| /// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a convolution. | ||
| pub fn calculate_conv_output_shape<const N: usize>( | ||
| in_shape: &Shape, | ||
| weight_shape: &Shape, | ||
| stride: &[usize; N], | ||
| padding: &[usize; N], | ||
| dilation: &[usize; N], | ||
| ) -> Result<Shape, ShapeError> { | ||
| if weight_shape.rank() != N + 2 { | ||
| return Err(ShapeError::RankMismatch { | ||
| left: weight_shape.rank(), | ||
| right: N + 2, | ||
| }); | ||
| } | ||
| if in_shape.rank() != N + 2 { | ||
| return Err(ShapeError::RankMismatch { | ||
| left: in_shape.rank(), | ||
| right: N + 2, | ||
| }); | ||
| } | ||
| let kernel_size = &weight_shape[2..]; | ||
| let mut out_shape = in_shape.clone(); | ||
| // Spatial dims | ||
| for (i, size_i) in out_shape[2..].iter_mut().enumerate() { | ||
| *size_i = | ||
| calculate_conv_output_size(kernel_size[i], stride[i], padding[i], dilation[i], *size_i); | ||
| } | ||
| // Output channels | ||
| out_shape[1] = weight_shape[0]; | ||
| Ok(out_shape) | ||
| } | ||
| /// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a transposed convolution. | ||
| pub fn calculate_conv_transpose_output_shape<const N: usize>( | ||
| in_shape: &Shape, | ||
| weight_shape: &Shape, | ||
| stride: &[usize; N], | ||
| padding: &[usize; N], | ||
| padding_out: &[usize; N], | ||
| dilation: &[usize; N], | ||
| groups: usize, | ||
| ) -> Result<Shape, ShapeError> { | ||
| if weight_shape.rank() != N + 2 { | ||
| return Err(ShapeError::RankMismatch { | ||
| left: weight_shape.rank(), | ||
| right: N + 2, | ||
| }); | ||
| } | ||
| if in_shape.rank() != N + 2 { | ||
| return Err(ShapeError::RankMismatch { | ||
| left: in_shape.rank(), | ||
| right: N + 2, | ||
| }); | ||
| } | ||
| let kernel_size = &weight_shape[2..]; | ||
| let mut out_shape = in_shape.clone(); | ||
| // Spatial dims | ||
| for (i, size_i) in out_shape[2..].iter_mut().enumerate() { | ||
| *size_i = calculate_conv_transpose_output_size( | ||
| kernel_size[i], | ||
| stride[i], | ||
| padding[i], | ||
| padding_out[i], | ||
| dilation[i], | ||
| *size_i, | ||
| ); | ||
| } | ||
| // Output channels | ||
| out_shape[1] = weight_shape[1] * groups; | ||
| Ok(out_shape) | ||
| } | ||
/// Calculate the expected padding size required when applying a convolution.
///
/// Computes `ceil((stride * (size_out - 1) - size_in + kernel_size) / 2)`,
/// clamped at zero when the expression is negative (no padding needed).
///
/// Implemented with exact integer arithmetic: the previous `f32` round-trip
/// lost precision for sizes above 2^24 and relied on a saturating
/// float-to-usize cast for the negative case.
pub fn calculate_conv_padding(
    kernel_size: usize,
    stride: usize,
    size_in: usize,
    size_out: usize,
) -> usize {
    // Signed intermediate: the numerator is negative when the output already
    // fits without padding.
    let numerator =
        stride as i64 * (size_out as i64 - 1) - size_in as i64 + kernel_size as i64;
    if numerator <= 0 {
        0
    } else {
        // Ceiling division by 2 for a positive numerator.
        ((numerator + 1) / 2) as usize
    }
}
/// Calculate the expected output size when doing a convolution operation.
///
/// Uses the standard formula: the dilated kernel spans
/// `dilation * (kernel_size - 1) + 1` input positions, and the window slides
/// by `stride` over the padded input.
pub fn calculate_conv_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    let effective_kernel = dilation * (kernel_size - 1) + 1;
    (size_in + 2 * padding - effective_kernel) / stride + 1
}
/// Calculate the expected output sizes when doing a convolution operation.
///
/// Applies the per-axis convolution output formula to every entry of
/// `size_in`; all parameter slices are indexed in lock-step and must be at
/// least as long as `size_in`.
pub fn calculate_conv_output_sizes(
    kernel_size: &[usize],
    stride: &[usize],
    padding: &[usize],
    dilation: &[usize],
    size_in: &[usize],
) -> Vec<usize> {
    let mut sizes = Vec::with_capacity(size_in.len());
    for (i, &input) in size_in.iter().enumerate() {
        // Same formula as `calculate_conv_output_size`, inlined per axis.
        let effective_kernel = dilation[i] * (kernel_size[i] - 1) + 1;
        sizes.push((input + 2 * padding[i] - effective_kernel) / stride[i] + 1);
    }
    sizes
}
/// Calculate the expected output size when doing a transposed convolution operation.
///
/// Inverse of the convolution size formula: each input position is spaced by
/// `stride`, the dilated kernel adds `dilation * (kernel_size - 1) + 1`, and
/// `padding_out` resolves the ambiguity introduced by the forward stride.
pub fn calculate_conv_transpose_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    padding_out: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    let effective_kernel = dilation * (kernel_size - 1) + 1;
    stride * (size_in - 1) + effective_kernel + padding_out - 2 * padding
}
/// Calculate the expected output size when doing a pooling operation.
///
/// Identical formula to the convolution case: the dilated window spans
/// `dilation * (kernel_size - 1) + 1` positions over the padded input.
pub fn calculate_pool_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    let window = dilation * (kernel_size - 1) + 1;
    (size_in + 2 * padding - window) / stride + 1
}
| /// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `x`. | ||
| pub(crate) fn conv1d_x_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvOptions<1>, | ||
| ) -> FloatTensor<B> { | ||
| let weight_shape = weight.shape(); | ||
| let [_batch_size, _, length_in] = x.shape().dims(); | ||
| let [_batch_size, _channels_out, length_out] = output_grad.shape().dims(); | ||
| let [_, _, kernel_size] = weight_shape.dims(); | ||
| let padding_out = calculate_padding_out( | ||
| kernel_size, | ||
| options.stride[0], | ||
| options.padding[0], | ||
| options.dilation[0], | ||
| length_in, | ||
| length_out, | ||
| ); | ||
| B::conv_transpose1d( | ||
| output_grad, | ||
| weight, | ||
| None, | ||
| ConvTransposeOptions::new( | ||
| options.stride, | ||
| options.padding, | ||
| [padding_out], | ||
| options.dilation, | ||
| options.groups, | ||
| ), | ||
| ) | ||
| } | ||
| /// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `weight`. | ||
| pub(crate) fn conv1d_weight_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvOptions<1>, | ||
| ) -> FloatTensor<B> { | ||
| let weight_dtype = weight.dtype(); | ||
| let weight_shape = weight.shape(); | ||
| let weight_device = B::float_device(&weight); | ||
| match options.groups == 1 { | ||
| true => conv1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options), | ||
| false => conv1d_weight_grad_groups::<B>( | ||
| x, | ||
| B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), | ||
| output_grad, | ||
| options, | ||
| ), | ||
| } | ||
| } | ||
| /// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `bias`. | ||
| pub(crate) fn conv1d_bias_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| bias: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| ) -> FloatTensor<B> { | ||
| let [batch_size, _, _length_in] = x.shape().dims(); | ||
| let [_batch_size, channels_out, length_out] = output_grad.shape().dims(); | ||
| let grad = B::float_swap_dims(output_grad, 0, 1); | ||
| let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out])); | ||
| let grad = B::float_sum_dim(grad, 1); | ||
| B::float_reshape(grad, bias.shape()) | ||
| } | ||
| /// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `x`. | ||
| pub(crate) fn conv2d_x_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvOptions<2>, | ||
| ) -> FloatTensor<B> { | ||
| let weight_shape = weight.shape(); | ||
| let [_batch_size, _channels_in, height_in, width_in] = x.shape().dims(); | ||
| let [_, _, height_out, width_out] = output_grad.shape().dims(); | ||
| let [_channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims(); | ||
| let padding_1_out = calculate_padding_out( | ||
| kernel_size_1, | ||
| options.stride[0], | ||
| options.padding[0], | ||
| options.dilation[0], | ||
| height_in, | ||
| height_out, | ||
| ); | ||
| let padding_2_out = calculate_padding_out( | ||
| kernel_size_2, | ||
| options.stride[1], | ||
| options.padding[1], | ||
| options.dilation[1], | ||
| width_in, | ||
| width_out, | ||
| ); | ||
| B::conv_transpose2d( | ||
| output_grad, | ||
| weight, | ||
| None, | ||
| ConvTransposeOptions::new( | ||
| options.stride, | ||
| options.padding, | ||
| [padding_1_out, padding_2_out], | ||
| options.dilation, | ||
| options.groups, | ||
| ), | ||
| ) | ||
| } | ||
| /// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `weight`. | ||
| pub(crate) fn conv2d_weight_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvOptions<2>, | ||
| ) -> FloatTensor<B> { | ||
| let weight_dtype = weight.dtype(); | ||
| let weight_shape = weight.shape(); | ||
| let weight_device = B::float_device(&weight); | ||
| match options.groups == 1 { | ||
| true => conv2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options), | ||
| false => conv2d_weight_grad_groups::<B>( | ||
| x, | ||
| B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), | ||
| output_grad, | ||
| options, | ||
| ), | ||
| } | ||
| } | ||
| /// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `bias`. | ||
| pub(crate) fn conv2d_bias_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| bias: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| ) -> FloatTensor<B> { | ||
| let weight_shape = weight.shape(); | ||
| let [batch_size, _channels_in, _height_in, _width_in] = x.shape().dims(); | ||
| let [_, _, height_out, width_out] = output_grad.shape().dims(); | ||
| let [channels_out, _, _kernel_size_1, _kernel_size_2] = weight_shape.dims(); | ||
| let grad = B::float_swap_dims(output_grad, 0, 1); | ||
| let grad = B::float_reshape( | ||
| grad, | ||
| Shape::new([channels_out, batch_size * height_out * width_out]), | ||
| ); | ||
| let grad = B::float_sum_dim(grad, 1); | ||
| B::float_reshape(grad, bias.shape()) | ||
| } | ||
| /// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `x`. | ||
| pub(crate) fn conv3d_x_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvOptions<3>, | ||
| ) -> FloatTensor<B> { | ||
| let weight_shape = weight.shape(); | ||
| let [_batch_size, _channels_in, depth_in, height_in, width_in] = x.shape().dims(); | ||
| let [_, _, depth_out, height_out, width_out] = output_grad.shape().dims(); | ||
| let [ | ||
| _channels_out, | ||
| _, | ||
| kernel_size_1, | ||
| kernel_size_2, | ||
| kernel_size_3, | ||
| ] = weight_shape.dims(); | ||
| let padding_1_out = calculate_padding_out( | ||
| kernel_size_1, | ||
| options.stride[0], | ||
| options.padding[0], | ||
| options.dilation[0], | ||
| depth_in, | ||
| depth_out, | ||
| ); | ||
| let padding_2_out = calculate_padding_out( | ||
| kernel_size_2, | ||
| options.stride[1], | ||
| options.padding[1], | ||
| options.dilation[1], | ||
| height_in, | ||
| height_out, | ||
| ); | ||
| let padding_3_out = calculate_padding_out( | ||
| kernel_size_3, | ||
| options.stride[2], | ||
| options.padding[2], | ||
| options.dilation[2], | ||
| width_in, | ||
| width_out, | ||
| ); | ||
| B::conv_transpose3d( | ||
| output_grad, | ||
| weight, | ||
| None, | ||
| ConvTransposeOptions::new( | ||
| options.stride, | ||
| options.padding, | ||
| [padding_1_out, padding_2_out, padding_3_out], | ||
| options.dilation, | ||
| options.groups, | ||
| ), | ||
| ) | ||
| } | ||
| /// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `weight`. | ||
| pub(crate) fn conv3d_weight_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvOptions<3>, | ||
| ) -> FloatTensor<B> { | ||
| let weight_dtype = weight.dtype(); | ||
| let weight_shape = weight.shape(); | ||
| let weight_device = B::float_device(&weight); | ||
| match options.groups == 1 { | ||
| true => conv3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options), | ||
| false => conv3d_weight_grad_groups::<B>( | ||
| x, | ||
| B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), | ||
| output_grad, | ||
| options, | ||
| ), | ||
| } | ||
| } | ||
| /// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `bias`. | ||
| pub(crate) fn conv3d_bias_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| bias: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| ) -> FloatTensor<B> { | ||
| let weight_shape = weight.shape(); | ||
| let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = x.shape().dims(); | ||
| let [_, _, depth_out, height_out, width_out] = output_grad.shape().dims(); | ||
| let [ | ||
| channels_out, | ||
| _, | ||
| _kernel_size_1, | ||
| _kernel_size_2, | ||
| _kernel_size_3, | ||
| ] = weight_shape.dims(); | ||
| let grad = B::float_swap_dims(output_grad, 0, 1); | ||
| let grad = B::float_reshape( | ||
| grad, | ||
| Shape::new([ | ||
| channels_out, | ||
| batch_size * depth_out * height_out * width_out, | ||
| ]), | ||
| ); | ||
| let grad = B::float_sum_dim(grad, 1); | ||
| B::float_reshape(grad, bias.shape()) | ||
| } | ||
| /// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `x`. | ||
| pub(crate) fn conv_transpose1d_x_backward<B: Backend>( | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvTransposeOptions<1>, | ||
| ) -> FloatTensor<B> { | ||
| B::conv1d( | ||
| output_grad, | ||
| weight, | ||
| None, | ||
| ConvOptions::new( | ||
| options.stride, | ||
| options.padding, | ||
| options.dilation, | ||
| options.groups, | ||
| ), | ||
| ) | ||
| } | ||
| /// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `weight`. | ||
| pub(crate) fn conv_transpose1d_weight_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvTransposeOptions<1>, | ||
| ) -> FloatTensor<B> { | ||
| let weight_dtype = weight.dtype(); | ||
| let weight_shape = weight.shape(); | ||
| let weight_device = B::float_device(&weight); | ||
| match options.groups == 1 { | ||
| true => conv_transpose1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options), | ||
| false => conv_transpose1d_weight_grad_groups::<B>( | ||
| x, | ||
| B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), | ||
| output_grad, | ||
| options, | ||
| ), | ||
| } | ||
| } | ||
| /// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `bias`. | ||
| pub(crate) fn conv_transpose1d_bias_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| bias: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| ) -> FloatTensor<B> { | ||
| let [batch_size, _channels_in, _] = x.shape().dims(); | ||
| let [_, channels_out, length_out] = output_grad.shape().dims(); | ||
| let grad = B::float_swap_dims(output_grad, 0, 1); | ||
| let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out])); | ||
| let grad = B::float_sum_dim(grad, 1); | ||
| B::float_reshape(grad, bias.shape()) | ||
| } | ||
| /// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `x`. | ||
| pub(crate) fn conv_transpose2d_x_backward<B: Backend>( | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvTransposeOptions<2>, | ||
| ) -> FloatTensor<B> { | ||
| B::conv2d( | ||
| output_grad, | ||
| weight, | ||
| None, | ||
| ConvOptions::new( | ||
| options.stride, | ||
| options.padding, | ||
| options.dilation, | ||
| options.groups, | ||
| ), | ||
| ) | ||
| } | ||
| /// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `weight`. | ||
| pub(crate) fn conv_transpose2d_weight_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvTransposeOptions<2>, | ||
| ) -> FloatTensor<B> { | ||
| let weight_dtype = weight.dtype(); | ||
| let weight_shape = weight.shape(); | ||
| let weight_device = B::float_device(&weight); | ||
| match options.groups == 1 { | ||
| true => conv_transpose2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options), | ||
| false => conv_transpose2d_weight_grad_groups::<B>( | ||
| x, | ||
| B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), | ||
| output_grad, | ||
| options, | ||
| ), | ||
| } | ||
| } | ||
| /// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `bias`. | ||
| pub(crate) fn conv_transpose2d_bias_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| bias: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| ) -> FloatTensor<B> { | ||
| let [batch_size, _channels_in, _, _] = x.shape().dims(); | ||
| let [_, channels_out, height_out, width_out] = output_grad.shape().dims(); | ||
| let grad = B::float_swap_dims(output_grad, 0, 1); | ||
| let grad = B::float_reshape( | ||
| grad, | ||
| Shape::new([channels_out, batch_size * height_out * width_out]), | ||
| ); | ||
| let grad = B::float_sum_dim(grad, 1); | ||
| B::float_reshape(grad, bias.shape()) | ||
| } | ||
| /// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `x`. | ||
| pub(crate) fn conv_transpose3d_x_backward<B: Backend>( | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvTransposeOptions<3>, | ||
| ) -> FloatTensor<B> { | ||
| B::conv3d( | ||
| output_grad, | ||
| weight, | ||
| None, | ||
| ConvOptions::new( | ||
| options.stride, | ||
| options.padding, | ||
| options.dilation, | ||
| options.groups, | ||
| ), | ||
| ) | ||
| } | ||
| /// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `weight`. | ||
| pub(crate) fn conv_transpose3d_weight_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvTransposeOptions<3>, | ||
| ) -> FloatTensor<B> { | ||
| let weight_dtype = weight.dtype(); | ||
| let weight_shape = weight.shape(); | ||
| let weight_device = B::float_device(&weight); | ||
| match options.groups == 1 { | ||
| true => conv_transpose3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options), | ||
| false => conv_transpose3d_weight_grad_groups::<B>( | ||
| x, | ||
| B::float_zeros(weight_shape, &weight_device, weight_dtype.into()), | ||
| output_grad, | ||
| options, | ||
| ), | ||
| } | ||
| } | ||
| /// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `bias`. | ||
| pub(crate) fn conv_transpose3d_bias_backward<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| bias: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| ) -> FloatTensor<B> { | ||
| let [batch_size, _channels_in, _, _, _] = x.shape().dims(); | ||
| let [_, channels_out, depth_out, height_out, width_out] = output_grad.shape().dims(); | ||
| let grad = B::float_swap_dims(output_grad, 0, 1); | ||
| let grad = B::float_reshape( | ||
| grad, | ||
| Shape::new([ | ||
| channels_out, | ||
| batch_size * depth_out * height_out * width_out, | ||
| ]), | ||
| ); | ||
| let grad = B::float_sum_dim(grad, 1); | ||
| B::float_reshape(grad, bias.shape()) | ||
| } | ||
| /// Execute a 1D convolution using a 2D convolution. | ||
| pub(crate) fn conv1d_from_conv2d<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| bias: Option<FloatTensor<B>>, | ||
| options: ConvOptions<1>, | ||
| ) -> FloatTensor<B> { | ||
| let [channels_out, _channels_in, kernel_size] = weight.shape().dims(); | ||
| let [batch_size, channels_in, length_in] = x.shape().dims(); | ||
| let weight = B::float_reshape( | ||
| weight, | ||
| Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]), | ||
| ); | ||
| let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); | ||
| let tensor = B::conv2d( | ||
| x, | ||
| weight, | ||
| bias, | ||
| ConvOptions::new( | ||
| [options.stride[0], 1], | ||
| [options.padding[0], 0], | ||
| [options.dilation[0], 1], | ||
| options.groups, | ||
| ), | ||
| ); | ||
| let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims(); | ||
| B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out])) | ||
| } | ||
| /// Execute a 1D transposed convolution using a 2D transposed convolution. | ||
| pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| weight: FloatTensor<B>, | ||
| bias: Option<FloatTensor<B>>, | ||
| options: ConvTransposeOptions<1>, | ||
| ) -> FloatTensor<B> { | ||
| let [channels_in, channels_out, kernel_size] = weight.shape().dims(); | ||
| let [batch_size, _channels_in, length_in] = x.shape().dims(); | ||
| let weight = B::float_reshape( | ||
| weight, | ||
| Shape::new([channels_in, channels_out, kernel_size, 1]), | ||
| ); | ||
| let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1])); | ||
| let tensor = B::conv_transpose2d( | ||
| x, | ||
| weight, | ||
| bias, | ||
| ConvTransposeOptions::new( | ||
| [options.stride[0], 1], | ||
| [options.padding[0], 0], | ||
| [options.padding_out[0], 0], | ||
| [options.dilation[0], 1], | ||
| options.groups, | ||
| ), | ||
| ); | ||
| let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims(); | ||
| B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out])) | ||
| } | ||
| fn conv1d_weight_grad_no_groups<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| weight_shape: Shape, | ||
| options: ConvOptions<1>, | ||
| ) -> FloatTensor<B> { | ||
| let x_swapped = B::float_swap_dims(x, 0, 1); | ||
| let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); | ||
| let weight_grad_swapped = B::conv1d( | ||
| x_swapped, | ||
| output_grad_swapped, | ||
| None, | ||
| ConvOptions::new(options.dilation, options.padding, options.stride, 1), | ||
| ); | ||
| let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); | ||
| if weight_grad.shape() != weight_shape { | ||
| let slices = vec![ | ||
| Slice::from(0..weight_shape[0]), | ||
| Slice::from(0..weight_shape[1]), | ||
| Slice::from(0..weight_shape[2]), | ||
| ]; | ||
| weight_grad = B::float_slice(weight_grad, &slices); | ||
| } | ||
| weight_grad | ||
| } | ||
| fn conv2d_weight_grad_no_groups<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| weight_shape: Shape, | ||
| options: ConvOptions<2>, | ||
| ) -> FloatTensor<B> { | ||
| let x_swapped = B::float_swap_dims(x, 0, 1); | ||
| let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); | ||
| let weight_grad_swapped = B::conv2d( | ||
| x_swapped, | ||
| output_grad_swapped, | ||
| None, | ||
| ConvOptions::new(options.dilation, options.padding, options.stride, 1), | ||
| ); | ||
| let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); | ||
| if weight_grad.shape() != weight_shape { | ||
| let slices = vec![ | ||
| Slice::from(0..weight_shape[0]), | ||
| Slice::from(0..weight_shape[1]), | ||
| Slice::from(0..weight_shape[2]), | ||
| Slice::from(0..weight_shape[3]), | ||
| ]; | ||
| weight_grad = B::float_slice(weight_grad, &slices); | ||
| } | ||
| weight_grad | ||
| } | ||
| fn conv3d_weight_grad_no_groups<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| weight_shape: Shape, | ||
| options: ConvOptions<3>, | ||
| ) -> FloatTensor<B> { | ||
| let x_swapped = B::float_swap_dims(x, 0, 1); | ||
| let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); | ||
| let weight_grad_swapped = B::conv3d( | ||
| x_swapped, | ||
| output_grad_swapped, | ||
| None, | ||
| ConvOptions::new(options.dilation, options.padding, options.stride, 1), | ||
| ); | ||
| let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); | ||
| if weight_grad.shape() != weight_shape { | ||
| let slices = vec![ | ||
| Slice::from(0..weight_shape[0]), | ||
| Slice::from(0..weight_shape[1]), | ||
| Slice::from(0..weight_shape[2]), | ||
| Slice::from(0..weight_shape[3]), | ||
| Slice::from(0..weight_shape[4]), | ||
| ]; | ||
| weight_grad = B::float_slice(weight_grad, &slices); | ||
| } | ||
| weight_grad | ||
| } | ||
| fn conv1d_weight_grad_groups<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| mut weight_grad: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| options: ConvOptions<1>, | ||
| ) -> FloatTensor<B> { | ||
| let [channels_out, increment_ci, kernel_size] = weight_grad.shape().dims(); | ||
| let increment_co = channels_out / options.groups; | ||
| let x_swapped = B::float_swap_dims(x, 0, 1); | ||
| let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); | ||
| for g in 0..options.groups { | ||
| let start_idx_ci = g * increment_ci; | ||
| let end_idx_ci = (g + 1) * increment_ci; | ||
| let start_idx_co = g * increment_co; | ||
| let end_idx_co = (g + 1) * increment_co; | ||
| let x_slice = vec![Slice::new( | ||
| start_idx_ci as isize, | ||
| Some(end_idx_ci as isize), | ||
| 1, | ||
| )]; | ||
| let x = B::float_slice(x_swapped.clone(), &x_slice); | ||
| let grad_slice = vec![Slice::new( | ||
| start_idx_co as isize, | ||
| Some(end_idx_co as isize), | ||
| 1, | ||
| )]; | ||
| let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice); | ||
| let mut weight_grad_tmp = B::conv1d( | ||
| x, | ||
| grad, | ||
| None, | ||
| ConvOptions::new(options.dilation, options.padding, options.stride, 1), | ||
| ); | ||
| weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1); | ||
| weight_grad = B::float_slice_assign( | ||
| weight_grad, | ||
| &[ | ||
| Slice::from(start_idx_co..end_idx_co), | ||
| Slice::from(0..increment_ci), | ||
| Slice::from(0..kernel_size), | ||
| ], | ||
| weight_grad_tmp, | ||
| ); | ||
| } | ||
| weight_grad | ||
| } | ||
/// Compute the weight gradient of a 2D convolution with `groups > 1`.
///
/// For every group, the matching input-channel and output-channel slices are
/// correlated with `conv2d` and the result is written into the corresponding
/// region of `weight_grad`.
fn conv2d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<2>,
) -> FloatTensor<B> {
    // Weight layout: (C_out, C_in / groups, K1, K2).
    let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = weight_grad.shape().dims();
    let increment_co = channels_out / options.groups;
    // Swap batch and channel axes so channels act as the correlation batch.
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    for g in 0..options.groups {
        // Channel ranges owned by group `g`.
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;
        // This group's input channels (dim 0 after the swap).
        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        // This group's output channels from the gradient.
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        // Stride and dilation trade places for the weight-gradient correlation.
        let mut weight_grad_tmp = B::conv2d(
            x,
            grad,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims();
        // The correlation can overshoot the kernel extent; trim before assigning.
        if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 {
            let slices = vec![
                Slice::from(0..increment_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }
        // Write this group's gradient into its region of the full weight grad.
        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_co..end_idx_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ],
            weight_grad_tmp,
        );
    }
    weight_grad
}
/// Compute the weight gradient of a 3D convolution with `groups > 1`.
///
/// For every group, the matching input-channel and output-channel slices are
/// correlated with `conv3d` and the result is written into the corresponding
/// region of `weight_grad`.
fn conv3d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<3>,
) -> FloatTensor<B> {
    // Weight layout: (C_out, C_in / groups, K1, K2, K3).
    let [
        channels_out,
        increment_ci,
        kernel_size_1,
        kernel_size_2,
        kernel_size_3,
    ] = weight_grad.shape().dims();
    let increment_co = channels_out / options.groups;
    // Swap batch and channel axes so channels act as the correlation batch.
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    for g in 0..options.groups {
        // Channel ranges owned by group `g`.
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;
        // This group's input channels (dim 0 after the swap).
        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        // This group's output channels from the gradient.
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        // Stride and dilation trade places for the weight-gradient correlation.
        let mut weight_grad_tmp = B::conv3d(
            x,
            grad,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [
            _,
            _,
            kernel_size_1_tmp,
            kernel_size_2_tmp,
            kernel_size_3_tmp,
        ] = weight_grad_tmp.shape().dims();
        // The correlation can overshoot the kernel extent; trim before assigning.
        if kernel_size_1_tmp != kernel_size_1
            || kernel_size_2_tmp != kernel_size_2
            || kernel_size_3_tmp != kernel_size_3
        {
            let slices = vec![
                Slice::from(0..increment_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }
        // Write this group's gradient into its region of the full weight grad.
        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_co..end_idx_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ],
            weight_grad_tmp,
        );
    }
    weight_grad
}
| fn conv_transpose1d_weight_grad_no_groups<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| weight_shape: Shape, | ||
| options: ConvTransposeOptions<1>, | ||
| ) -> FloatTensor<B> { | ||
| let x_swapped = B::float_swap_dims(x, 0, 1); | ||
| let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); | ||
| let weight_grad_swapped = B::conv1d( | ||
| output_grad_swapped, | ||
| x_swapped, | ||
| None, | ||
| ConvOptions::new(options.dilation, options.padding, options.stride, 1), | ||
| ); | ||
| let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); | ||
| let grad_shape = weight_grad.shape(); | ||
| if grad_shape != weight_shape { | ||
| let slices = vec![ | ||
| Slice::from(0..weight_shape[0]), | ||
| Slice::from(0..weight_shape[1]), | ||
| Slice::from(0..weight_shape[2]), | ||
| ]; | ||
| weight_grad = B::float_slice(weight_grad, &slices); | ||
| } | ||
| weight_grad | ||
| } | ||
| fn conv_transpose2d_weight_grad_no_groups<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| weight_shape: Shape, | ||
| options: ConvTransposeOptions<2>, | ||
| ) -> FloatTensor<B> { | ||
| let x_swapped = B::float_swap_dims(x, 0, 1); | ||
| let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); | ||
| let weight_grad_swapped = B::conv2d( | ||
| output_grad_swapped, | ||
| x_swapped, | ||
| None, | ||
| ConvOptions::new(options.dilation, options.padding, options.stride, 1), | ||
| ); | ||
| let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); | ||
| let grad_shape = weight_grad.shape(); | ||
| if grad_shape != weight_shape { | ||
| let slices = vec![ | ||
| Slice::from(0..weight_shape[0]), | ||
| Slice::from(0..weight_shape[1]), | ||
| Slice::from(0..weight_shape[2]), | ||
| Slice::from(0..weight_shape[3]), | ||
| ]; | ||
| weight_grad = B::float_slice(weight_grad, &slices); | ||
| } | ||
| weight_grad | ||
| } | ||
| fn conv_transpose3d_weight_grad_no_groups<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| output_grad: FloatTensor<B>, | ||
| weight_shape: Shape, | ||
| options: ConvTransposeOptions<3>, | ||
| ) -> FloatTensor<B> { | ||
| let x_swapped = B::float_swap_dims(x, 0, 1); | ||
| let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1); | ||
| let weight_grad_swapped = B::conv3d( | ||
| output_grad_swapped, | ||
| x_swapped, | ||
| None, | ||
| ConvOptions::new(options.dilation, options.padding, options.stride, 1), | ||
| ); | ||
| let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1); | ||
| let grad_shape = weight_grad.shape(); | ||
| if grad_shape != weight_shape { | ||
| let slices = vec![ | ||
| Slice::from(0..weight_shape[0]), | ||
| Slice::from(0..weight_shape[1]), | ||
| Slice::from(0..weight_shape[2]), | ||
| Slice::from(0..weight_shape[3]), | ||
| Slice::from(0..weight_shape[4]), | ||
| ]; | ||
| weight_grad = B::float_slice(weight_grad, &slices); | ||
| } | ||
| weight_grad | ||
| } | ||
/// Compute the weight gradient of a 1D transposed convolution with `groups > 1`.
///
/// Weight layout is (C_in, C_out / groups, K): per group, the output-gradient
/// slice is correlated against the input slice and the result is written into
/// the matching region of `weight_grad`.
fn conv_transpose1d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    let [channels_in, increment_co, kernel_size] = weight_grad.shape().dims();
    let increment_ci = channels_in / options.groups;
    // Swap batch and channel axes so channels act as the correlation batch.
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    for g in 0..options.groups {
        // Channel ranges owned by group `g`.
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;
        // This group's input channels.
        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        // This group's output channels from the gradient.
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        // Arguments are exchanged relative to the forward-conv case: the
        // output gradient is correlated against the input.
        let mut weight_grad_tmp = B::conv1d(
            grad,
            x,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [_, _, kernel_size_tmp] = weight_grad_tmp.shape().dims();
        // The correlation can overshoot the kernel extent; trim before assigning.
        if kernel_size_tmp != kernel_size {
            let slices = vec![
                Slice::from(0..increment_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }
        // Write this group's gradient into its region of the full weight grad.
        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_ci..end_idx_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size),
            ],
            weight_grad_tmp,
        );
    }
    weight_grad
}
/// Compute the weight gradient of a 2D transposed convolution with `groups > 1`.
///
/// Weight layout is (C_in, C_out / groups, K1, K2): per group, the
/// output-gradient slice is correlated against the input slice and written
/// into the matching region of `weight_grad`.
fn conv_transpose2d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
    let [channels_in, increment_co, kernel_size_1, kernel_size_2] = weight_grad.shape().dims();
    let increment_ci = channels_in / options.groups;
    // Swap batch and channel axes so channels act as the correlation batch.
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    for g in 0..options.groups {
        // Channel ranges owned by group `g`.
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;
        // This group's input channels.
        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        // This group's output channels from the gradient.
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        // Arguments are exchanged relative to the forward-conv case: the
        // output gradient is correlated against the input.
        let mut weight_grad_tmp = B::conv2d(
            grad,
            x,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims();
        // The correlation can overshoot the kernel extent; trim before assigning.
        if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 {
            let slices = vec![
                Slice::from(0..increment_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }
        // Write this group's gradient into its region of the full weight grad.
        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_ci..end_idx_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ],
            weight_grad_tmp,
        );
    }
    weight_grad
}
/// Compute the weight gradient of a 3D transposed convolution with `groups > 1`.
///
/// Weight layout is (C_in, C_out / groups, K1, K2, K3): per group, the
/// output-gradient slice is correlated against the input slice and written
/// into the matching region of `weight_grad`.
fn conv_transpose3d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
    let [
        channels_in,
        increment_co,
        kernel_size_1,
        kernel_size_2,
        kernel_size_3,
    ] = weight_grad.shape().dims();
    let increment_ci = channels_in / options.groups;
    // Swap batch and channel axes so channels act as the correlation batch.
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    for g in 0..options.groups {
        // Channel ranges owned by group `g`.
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;
        // This group's input channels.
        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        // This group's output channels from the gradient.
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        // Arguments are exchanged relative to the forward-conv case: the
        // output gradient is correlated against the input.
        let mut weight_grad_tmp = B::conv3d(
            grad,
            x,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [
            _,
            _,
            kernel_size_1_tmp,
            kernel_size_2_tmp,
            kernel_size_3_tmp,
        ] = weight_grad_tmp.shape().dims();
        // The correlation can overshoot the kernel extent; trim before assigning.
        if kernel_size_1_tmp != kernel_size_1
            || kernel_size_2_tmp != kernel_size_2
            || kernel_size_3_tmp != kernel_size_3
        {
            let slices = vec![
                Slice::from(0..increment_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }
        // Write this group's gradient into its region of the full weight grad.
        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_ci..end_idx_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ],
            weight_grad_tmp,
        );
    }
    weight_grad
}
/// Compute the extra output padding needed on one spatial dimension of a
/// transposed convolution.
///
/// `out = 1 + ceil((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride)`
/// is the (ceil-mode) number of output positions reachable from `size_in`
/// input elements with the given parameters; the result is how far that
/// exceeds the requested `size_out`, clamped at zero.
///
/// NOTE(review): like the original, this assumes
/// `size_in + 2 * padding >= dilation * (kernel_size - 1) + 1`; smaller
/// inputs would underflow — TODO confirm callers guarantee this.
fn calculate_padding_out(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
    size_out: usize,
) -> usize {
    // With a unit stride every output size is reachable without extra padding.
    if stride <= 1 {
        return 0;
    }
    // Integer ceiling division instead of a round-trip through f64: exact for
    // all usize values (f64 loses integer precision above 2^53).
    let numerator = size_in + 2 * padding - dilation * (kernel_size - 1) - 1;
    let out = 1 + numerator.div_ceil(stride);
    // Clamp at zero without casting through i64.
    out.saturating_sub(size_out)
}
#[cfg(test)]
mod tests {
    use super::*;

    // "Same" convolution: kernel 3, stride 1, padding 1 preserves the size.
    #[test]
    fn test_calculate_output_size_1() {
        let kernel_size = 3;
        let stride = 1;
        let padding = 1;
        let size_in = 3;
        let dilation = 1;
        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
        assert_eq!(size_out, 3);
    }

    // Strided convolution with asymmetric-looking parameters.
    #[test]
    fn test_calculate_output_size_2() {
        let kernel_size = 5;
        let stride = 2;
        let padding = 3;
        let size_in = 27;
        let dilation = 1;
        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
        assert_eq!(size_out, 15);
    }

    // Dilation enlarges the effective kernel and shrinks the output.
    #[test]
    fn test_calculate_output_size_3() {
        let kernel_size = 5;
        let stride = 2;
        let padding = 3;
        let size_in = 27;
        let dilation = 2;
        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
        assert_eq!(size_out, 13);
    }

    // calculate_conv_padding round-trips: padding chosen for "same" output.
    #[test]
    fn test_calculate_same_padding_1() {
        let kernel_size = 3;
        let stride = 1;
        let size_in = 3;
        let dilation = 1;
        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);
        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
        assert_eq!(size_in, size_out, "Expected size");
    }

    // Same round-trip with a stride greater than one.
    #[test]
    fn test_calculate_same_padding_2() {
        let kernel_size = 3;
        let stride = 2;
        let size_in = 7;
        let dilation = 1;
        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);
        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
        assert_eq!(size_in, size_out, "Expected size");
    }

    // Padding computed for an enlarged target output reproduces that output.
    #[test]
    fn test_calculate_output_padding_1() {
        let kernel_size = 3;
        let stride = 2;
        let size_in = 7;
        let size_out = 10;
        let dilation = 1;
        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out);
        let size_out_expected =
            calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);
        assert_eq!(size_out, size_out_expected, "Expected size");
    }

    // Full 2D output-shape computation: (N, C_out, H_out, W_out).
    #[test]
    fn test_expect_conv2d_output_shape() {
        // in channels: 3
        // out channels: 8
        // size in: [27, 3]
        // kernel size: [5, 3]
        let stride = [2, 1];
        let padding = [3, 1];
        let dilation = [2, 1];
        let shape = calculate_conv_output_shape(
            &Shape::new([12, 3, 27, 3]),
            &Shape::new([8, 3, 5, 3]),
            &stride,
            &padding,
            &dilation,
        )
        .unwrap();
        assert_eq!(shape, Shape::new([12, 8, 13, 3]))
    }
}
| use crate::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode}; | ||
| use crate::{ElementConversion, Shape, Slice, TensorMetadata, backend::Backend, ops::FloatTensor}; | ||
| use alloc::vec; | ||
| /// Reference implementation of grid_sample_2d that supports all options. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor being sampled from, shape (N, C, H_in, W_in) | ||
| /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1]. | ||
| /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right | ||
| /// * `options` - Grid sampling options | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with shape (N, C, H_out, W_out) | ||
| pub fn float_grid_sample_2d_ref<B: Backend>( | ||
| tensor: FloatTensor<B>, | ||
| grid: FloatTensor<B>, | ||
| options: GridSampleOptions, | ||
| ) -> FloatTensor<B> { | ||
| match options.mode { | ||
| InterpolateMode::Bilinear => float_grid_sample_2d_bilinear::<B>( | ||
| tensor, | ||
| grid, | ||
| options.padding_mode, | ||
| options.align_corners, | ||
| ), | ||
| _ => todo!( | ||
| "Default implementation for grid_sample_2d with {:?} unimplemented", | ||
| options.mode | ||
| ), | ||
| } | ||
| } | ||
/// Bilinear grid sampling implementation.
///
/// Samples `tensor` (N, C, H_in, W_in) at the normalized (x, y) locations in
/// `grid` (N, H_out, W_out, 2) and returns a tensor of shape
/// (N, C, H_out, W_out). For every sample point the four surrounding pixels
/// are gathered and blended with bilinear weights; `padding_mode` decides how
/// out-of-bounds coordinates are handled (clamp, reflect, or zero-fill).
fn float_grid_sample_2d_bilinear<B: Backend>(
    tensor: FloatTensor<B>,
    grid: FloatTensor<B>,
    padding_mode: GridSamplePaddingMode,
    align_corners: bool,
) -> FloatTensor<B> {
    let n = tensor.shape().dims[0];
    let c = tensor.shape().dims[1];
    let h_in = tensor.shape().dims[2];
    let w_in = tensor.shape().dims[3];
    let h_out = grid.shape().dims[1];
    let w_out = grid.shape().dims[2];
    // Separate x and y coordinates from grid (last-axis index 0 is x, 1 is y)
    // shape: (N, H_out, W_out, 1)
    let grid_x_slice = vec![
        Slice::new(0, Some(n as isize), 1),
        Slice::new(0, Some(h_out as isize), 1),
        Slice::new(0, Some(w_out as isize), 1),
        Slice::new(0, Some(1), 1),
    ];
    let grid_y_slice = vec![
        Slice::new(0, Some(n as isize), 1),
        Slice::new(0, Some(h_out as isize), 1),
        Slice::new(0, Some(w_out as isize), 1),
        Slice::new(1, Some(2), 1),
    ];
    // Reshape each coordinate plane to (N, 1, H_out, W_out) so later
    // arithmetic broadcasts over the channel dimension.
    let grid_x = B::float_slice(grid.clone(), &grid_x_slice);
    let grid_x = B::float_reshape(grid_x, Shape::new([n, 1, h_out, w_out]));
    let grid_y = B::float_slice(grid.clone(), &grid_y_slice);
    let grid_y = B::float_reshape(grid_y, Shape::new([n, 1, h_out, w_out]));
    // Convert normalized grid coordinates [-1, 1] to pixel coordinates
    let w_in_f = w_in as f64;
    let h_in_f = h_in as f64;
    let (grid_x, grid_y) = if align_corners {
        // align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2
        // Maps -1 to 0 and 1 to width - 1
        let grid_x = B::float_add_scalar(grid_x, 1.0f32.elem());
        let grid_x = B::float_mul_scalar(grid_x, ((w_in_f - 1.0) / 2.0).elem());
        let grid_y = B::float_add_scalar(grid_y, 1.0f32.elem());
        let grid_y = B::float_mul_scalar(grid_y, ((h_in_f - 1.0) / 2.0).elem());
        (grid_x, grid_y)
    } else {
        // align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5
        // Maps -1 to -0.5 and 1 to width - 0.5
        let grid_x = B::float_add_scalar(grid_x, 1.0f32.elem());
        let grid_x = B::float_mul_scalar(grid_x, (w_in_f / 2.0).elem());
        let grid_x = B::float_sub_scalar(grid_x, 0.5f32.elem());
        let grid_y = B::float_add_scalar(grid_y, 1.0f32.elem());
        let grid_y = B::float_mul_scalar(grid_y, (h_in_f / 2.0).elem());
        let grid_y = B::float_sub_scalar(grid_y, 0.5f32.elem());
        (grid_x, grid_y)
    };
    // Apply padding mode to coordinates
    let (grid_x, grid_y) = match padding_mode {
        GridSamplePaddingMode::Border => {
            // Clamp coordinates to valid range [0, size-1]
            let grid_x = B::float_clamp(grid_x, 0.0f32.elem(), ((w_in - 1) as f32).elem());
            let grid_y = B::float_clamp(grid_y, 0.0f32.elem(), ((h_in - 1) as f32).elem());
            (grid_x, grid_y)
        }
        GridSamplePaddingMode::Reflection => {
            // Reflect coordinates at boundaries
            // For now, use a simplified reflection that works for common cases
            let grid_x = reflect_coordinates::<B>(grid_x, w_in_f);
            let grid_y = reflect_coordinates::<B>(grid_y, h_in_f);
            (grid_x, grid_y)
        }
        GridSamplePaddingMode::Zeros => {
            // Keep coordinates as-is, we'll mask out-of-bounds later
            (grid_x, grid_y)
        }
    };
    // Get floor indices for the four corners
    let grid_x_floored = B::float_floor(grid_x.clone());
    let grid_y_floored = B::float_floor(grid_y.clone());
    // Compute interpolation weights (fractional part)
    let x_frac = B::float_sub(grid_x.clone(), grid_x_floored.clone());
    let y_frac = B::float_sub(grid_y.clone(), grid_y_floored.clone());
    // Convert to integer indices: (x0, y0) is the top-left corner and
    // (x1, y1) = (x0 + 1, y0 + 1) the bottom-right.
    let x0 = B::float_into_int(grid_x_floored.clone());
    let y0 = B::float_into_int(grid_y_floored.clone());
    let x1 = B::float_into_int(B::float_add_scalar(grid_x_floored, 1.0f32.elem()));
    let y1 = B::float_into_int(B::float_add_scalar(grid_y_floored, 1.0f32.elem()));
    // Create masks for out-of-bounds coordinates (only used for zeros padding)
    let (mask_00, mask_01, mask_10, mask_11) = if padding_mode == GridSamplePaddingMode::Zeros {
        // A corner is valid when its index is inside [0, size) on both axes.
        let x0_valid = B::int_greater_equal_elem(x0.clone(), 0.elem());
        let x0_valid = B::bool_and(
            x0_valid,
            B::int_lower_elem(x0.clone(), (w_in as i32).elem()),
        );
        let x1_valid = B::int_greater_equal_elem(x1.clone(), 0.elem());
        let x1_valid = B::bool_and(
            x1_valid,
            B::int_lower_elem(x1.clone(), (w_in as i32).elem()),
        );
        let y0_valid = B::int_greater_equal_elem(y0.clone(), 0.elem());
        let y0_valid = B::bool_and(
            y0_valid,
            B::int_lower_elem(y0.clone(), (h_in as i32).elem()),
        );
        let y1_valid = B::int_greater_equal_elem(y1.clone(), 0.elem());
        let y1_valid = B::bool_and(
            y1_valid,
            B::int_lower_elem(y1.clone(), (h_in as i32).elem()),
        );
        (
            Some(B::bool_and(x0_valid.clone(), y0_valid.clone())),
            Some(B::bool_and(x0_valid.clone(), y1_valid.clone())),
            Some(B::bool_and(x1_valid.clone(), y0_valid)),
            Some(B::bool_and(x1_valid, y1_valid)),
        )
    } else {
        (None, None, None, None)
    };
    // Clamp indices to valid range for gather
    let x0_clamped = B::int_clamp(x0, 0.elem(), ((w_in - 1) as i32).elem());
    let x1_clamped = B::int_clamp(x1, 0.elem(), ((w_in - 1) as i32).elem());
    let y0_clamped = B::int_clamp(y0, 0.elem(), ((h_in - 1) as i32).elem());
    let y1_clamped = B::int_clamp(y1, 0.elem(), ((h_in - 1) as i32).elem());
    // Reshape indices for gather operation: rows are gathered along dim 2 and
    // columns along dim 4 of the expanded 5D view below.
    let y0_idx = B::int_reshape(y0_clamped.clone(), Shape::new([n, 1, h_out, w_out, 1]));
    let y0_idx = B::int_expand(y0_idx, Shape::new([n, c, h_out, w_out, w_in]));
    let y1_idx = B::int_reshape(y1_clamped.clone(), Shape::new([n, 1, h_out, w_out, 1]));
    let y1_idx = B::int_expand(y1_idx, Shape::new([n, c, h_out, w_out, w_in]));
    let x0_idx = B::int_reshape(x0_clamped, Shape::new([n, 1, h_out, w_out, 1]));
    let x0_idx = B::int_expand(x0_idx, Shape::new([n, c, h_out, w_out, 1]));
    let x1_idx = B::int_reshape(x1_clamped, Shape::new([n, 1, h_out, w_out, 1]));
    let x1_idx = B::int_expand(x1_idx, Shape::new([n, c, h_out, w_out, 1]));
    // Reshape tensor for gather operation
    let tensor = B::float_reshape(tensor, Shape::new([n, c, h_in, 1, w_in]));
    let tensor = B::float_expand(tensor, Shape::new([n, c, h_in, w_out, w_in]));
    // Gather samples from the four corners (first the row, then the column).
    let sample_00 = B::float_gather(2, tensor.clone(), y0_idx.clone());
    let sample_00 = B::float_gather(4, sample_00, x0_idx.clone());
    let sample_01 = B::float_gather(2, tensor.clone(), y1_idx.clone());
    let sample_01 = B::float_gather(4, sample_01, x0_idx.clone());
    let sample_10 = B::float_gather(2, tensor.clone(), y0_idx);
    let sample_10 = B::float_gather(4, sample_10, x1_idx.clone());
    let sample_11 = B::float_gather(2, tensor, y1_idx);
    let sample_11 = B::float_gather(4, sample_11, x1_idx);
    // Reshape samples to (N, C, H_out, W_out)
    let sample_00 = B::float_reshape(sample_00, Shape::new([n, c, h_out, w_out]));
    let sample_01 = B::float_reshape(sample_01, Shape::new([n, c, h_out, w_out]));
    let sample_10 = B::float_reshape(sample_10, Shape::new([n, c, h_out, w_out]));
    let sample_11 = B::float_reshape(sample_11, Shape::new([n, c, h_out, w_out]));
    // Apply masks for zeros padding (set out-of-bounds samples to 0)
    let (sample_00, sample_01, sample_10, sample_11) =
        if padding_mode == GridSamplePaddingMode::Zeros {
            // Masks are guaranteed Some here: they were built above for Zeros.
            let mask_00 = mask_00.unwrap();
            let mask_01 = mask_01.unwrap();
            let mask_10 = mask_10.unwrap();
            let mask_11 = mask_11.unwrap();
            // Invert each mask (mask_fill writes where the mask is true) and
            // broadcast it over the channel dimension.
            let mask_00_inv = B::bool_not(mask_00);
            let mask_00_inv = B::bool_reshape(mask_00_inv, Shape::new([n, 1, h_out, w_out]));
            let mask_00_inv = B::bool_expand(mask_00_inv, Shape::new([n, c, h_out, w_out]));
            let mask_01_inv = B::bool_not(mask_01);
            let mask_01_inv = B::bool_reshape(mask_01_inv, Shape::new([n, 1, h_out, w_out]));
            let mask_01_inv = B::bool_expand(mask_01_inv, Shape::new([n, c, h_out, w_out]));
            let mask_10_inv = B::bool_not(mask_10);
            let mask_10_inv = B::bool_reshape(mask_10_inv, Shape::new([n, 1, h_out, w_out]));
            let mask_10_inv = B::bool_expand(mask_10_inv, Shape::new([n, c, h_out, w_out]));
            let mask_11_inv = B::bool_not(mask_11);
            let mask_11_inv = B::bool_reshape(mask_11_inv, Shape::new([n, 1, h_out, w_out]));
            let mask_11_inv = B::bool_expand(mask_11_inv, Shape::new([n, c, h_out, w_out]));
            (
                B::float_mask_fill(sample_00, mask_00_inv, 0.0f32.elem()),
                B::float_mask_fill(sample_01, mask_01_inv, 0.0f32.elem()),
                B::float_mask_fill(sample_10, mask_10_inv, 0.0f32.elem()),
                B::float_mask_fill(sample_11, mask_11_inv, 0.0f32.elem()),
            )
        } else {
            (sample_00, sample_01, sample_10, sample_11)
        };
    // Compute bilinear interpolation weights: (1 - frac) for the low corner,
    // frac for the high corner, on each axis.
    let one_minus_x = B::float_neg(x_frac.clone());
    let one_minus_x = B::float_add_scalar(one_minus_x, 1.0f32.elem());
    let one_minus_y = B::float_neg(y_frac.clone());
    let one_minus_y = B::float_add_scalar(one_minus_y, 1.0f32.elem());
    let weight_00 = B::float_mul(one_minus_x.clone(), one_minus_y.clone());
    let weight_01 = B::float_mul(one_minus_x.clone(), y_frac.clone());
    let weight_10 = B::float_mul(x_frac.clone(), one_minus_y);
    let weight_11 = B::float_mul(x_frac, y_frac);
    // Bilinear interpolation: weighted sum of the four corner samples.
    let result = B::float_mul(sample_00, weight_00);
    let result = B::float_add(result, B::float_mul(sample_01, weight_01));
    let result = B::float_add(result, B::float_mul(sample_10, weight_10));
    B::float_add(result, B::float_mul(sample_11, weight_11))
}
| /// Reflect coordinates at boundaries for reflection padding. | ||
| /// | ||
| /// Uses the formula: reflected = 2 * bound - x for out-of-bounds coordinates. | ||
| fn reflect_coordinates<B: Backend>(coords: FloatTensor<B>, size: f64) -> FloatTensor<B> { | ||
| // Simple reflection: clamp to [0, size-1] after reflecting | ||
| // For values < 0: reflect at 0 -> -x | ||
| // For values >= size: reflect at size-1 -> 2*(size-1) - x | ||
| // This is a simplified implementation - full reflection would handle multiple reflections | ||
| let max_val = (size - 1.0) as f32; | ||
| // First handle negative values by taking absolute value | ||
| let coords = B::float_abs(coords); | ||
| // Then handle values > max by reflecting: 2*max - x | ||
| // But we need to detect which values need this | ||
| // Simplified: just clamp for now, proper reflection is complex | ||
| B::float_clamp(coords, 0.0f32.elem(), max_val.elem()) | ||
| } |
| /// Module with convolution operations. | ||
| pub mod conv; | ||
| /// Module with cat operation | ||
| pub(crate) mod cat; | ||
| /// Module with repeat operation | ||
| pub(crate) mod repeat_dim; | ||
| /// Module with unfold operations. | ||
| pub mod unfold; | ||
| /// Module with pooling operations. | ||
| pub mod pool; | ||
| /// Module for grid_sample operations | ||
| pub mod grid_sample; | ||
| mod base; | ||
| pub use base::*; |
| use crate::{ | ||
| Shape, TensorMetadata, | ||
| backend::Backend, | ||
| ops::{FloatTensor, IntTensor}, | ||
| }; | ||
| use super::{MaxPool1dBackward, MaxPool1dWithIndices}; | ||
| pub(crate) fn avg_pool1d_from_2d<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| kernel_size: usize, | ||
| stride: usize, | ||
| padding: usize, | ||
| count_include_pad: bool, | ||
| ) -> FloatTensor<B> { | ||
| let [batch_size, channels, length] = x.shape().dims(); | ||
| let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1])); | ||
| let x = B::avg_pool2d( | ||
| x, | ||
| [kernel_size, 1], | ||
| [stride, 1], | ||
| [padding, 0], | ||
| count_include_pad, | ||
| ); | ||
| let [batch_size, channels, length, _] = x.shape().dims(); | ||
| B::float_reshape(x, Shape::from([batch_size, channels, length])) | ||
| } | ||
| pub(crate) fn avg_pool1d_backward_from_2d<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| grad: FloatTensor<B>, | ||
| kernel_size: usize, | ||
| stride: usize, | ||
| padding: usize, | ||
| count_include_pad: bool, | ||
| ) -> FloatTensor<B> { | ||
| let [batch_size, channels, length_in] = x.shape().dims(); | ||
| let [_, _, length_out] = grad.shape().dims(); | ||
| let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1])); | ||
| let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1])); | ||
| let grad_x = B::avg_pool2d_backward( | ||
| x, | ||
| grad_x, | ||
| [kernel_size, 1], | ||
| [stride, 1], | ||
| [padding, 0], | ||
| count_include_pad, | ||
| ); | ||
| B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in])) | ||
| } | ||
| pub(crate) fn adaptive_avg_pool1d_from_2d<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| output_size: usize, | ||
| ) -> FloatTensor<B> { | ||
| let [batch_size, channels, length] = x.shape().dims(); | ||
| let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1])); | ||
| let x = B::adaptive_avg_pool2d(x, [output_size, 1]); | ||
| let [batch_size, channels, length, _] = x.shape().dims(); | ||
| B::float_reshape(x, Shape::from([batch_size, channels, length])) | ||
| } | ||
| pub(crate) fn adaptive_avg_pool1d_backward_from_2d<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| grad: FloatTensor<B>, | ||
| ) -> FloatTensor<B> { | ||
| let [batch_size, channels, length_in] = x.shape().dims(); | ||
| let [_, _, length_out] = grad.shape().dims(); | ||
| let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1])); | ||
| let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1])); | ||
| let grad_x = B::adaptive_avg_pool2d_backward(x, grad_x); | ||
| B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in])) | ||
| } | ||
| pub(crate) fn max_pool1d_from_2d<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| kernel_size: usize, | ||
| stride: usize, | ||
| padding: usize, | ||
| dilation: usize, | ||
| ) -> FloatTensor<B> { | ||
| let [batch_size, channels, length] = x.shape().dims(); | ||
| let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1])); | ||
| let x = B::max_pool2d( | ||
| x, | ||
| [kernel_size, 1], | ||
| [stride, 1], | ||
| [padding, 0], | ||
| [dilation, 1], | ||
| ); | ||
| let [batch_size, channels, length, _] = x.shape().dims(); | ||
| B::float_reshape(x, Shape::from([batch_size, channels, length])) | ||
| } | ||
/// 1d max pooling via 2d max pooling, returning both values and indices.
///
/// The input `[N, C, L]` is lifted to `[N, C, 1, L]` (unit *height*), with the
/// kernel, stride, padding, and dilation applied on the width axis.
///
/// NOTE(review): assumes `max_pool2d_with_indices` flattens indices within the
/// (H, W) plane, so with H = 1 each index equals the 1d position along `L` —
/// confirm against the backend's index convention.
pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> MaxPool1dWithIndices<B> {
    let [batch_size, channels, length] = x.shape().dims();
    // Unit height: all pooling parameters act on the width (= original length) axis.
    let x = B::float_reshape(x, Shape::from([batch_size, channels, 1, length]));
    let x = B::max_pool2d_with_indices(
        x,
        [1, kernel_size],
        [1, stride],
        [0, padding],
        [1, dilation],
    );
    let [batch_size, channels, _, length] = x.output.shape().dims();
    // Squeeze the dummy height dimension from both outputs.
    let output = B::float_reshape(x.output, Shape::from([batch_size, channels, length]));
    let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length]));
    MaxPool1dWithIndices::new(output, indices)
}
/// Backward pass of 1d max pooling, delegated to the 2d backward kernel.
///
/// NOTE(review): the forward pass lifts to `[N, C, 1, L]` while this backward
/// lifts to `[N, C, L, 1]`; the indices remain compatible only if a flat index
/// within the 2d plane equals the 1d position in both layouts (true whenever
/// H = 1 or W = 1) — confirm against the backend's index convention.
pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    output_grad: FloatTensor<B>,
    indices: IntTensor<B>,
) -> MaxPool1dBackward<B> {
    let [batch_size, channels, length_in] = x.shape().dims();
    let [_, _, length_out] = output_grad.shape().dims();
    // Lift input, gradient, and indices to 4d with a trailing unit dimension.
    let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
    let grad_x = B::float_reshape(
        output_grad,
        Shape::from([batch_size, channels, length_out, 1]),
    );
    let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1]));
    let grad_x = B::max_pool2d_with_indices_backward(
        x,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        [dilation, 1],
        grad_x,
        indices,
    )
    .x_grad;
    // Squeeze back to the 1d gradient layout.
    MaxPool1dBackward::new(B::float_reshape(
        grad_x,
        Shape::from([batch_size, channels, length_in]),
    ))
}
| use crate::{BasicOps, Slice, TensorKind, TensorMetadata, backend::Backend}; | ||
| use alloc::vec::Vec; | ||
| pub(crate) fn repeat_with_slice_assign<B: Backend, K: TensorKind<B> + BasicOps<B>>( | ||
| tensor: K::Primitive, | ||
| dim: usize, | ||
| times: usize, | ||
| ) -> K::Primitive { | ||
| let shape = tensor.shape(); | ||
| let device = K::device(&tensor); | ||
| let dtype = tensor.dtype(); | ||
| let original_dim_length = shape[dim]; | ||
| let shape = shape.repeat(dim, times).unwrap(); | ||
| let mut tensor_output = K::empty(shape.clone(), &device, dtype); | ||
| let indices_select_all = shape.iter().map(|d| 0..*d).collect::<Vec<_>>(); | ||
| let mut output_index = 0; | ||
| for _ in 0..times { | ||
| let mut indices = indices_select_all.clone(); | ||
| indices[dim] = output_index..output_index + original_dim_length; | ||
| output_index += original_dim_length; | ||
| // Convert ranges to Slice | ||
| let slices: Vec<Slice> = indices | ||
| .iter() | ||
| .map(|r| Slice::new(r.start as isize, Some(r.end as isize), 1)) | ||
| .collect(); | ||
| tensor_output = K::slice_assign(tensor_output, &slices, tensor.clone()); | ||
| } | ||
| tensor_output | ||
| } |
| use super::{ConvOptions, UnfoldOptions}; | ||
| use crate::backend::Backend; | ||
| use crate::ops::FloatTensor; | ||
| use crate::{ElementConversion, Shape, TensorData, TensorMetadata}; | ||
| use alloc::vec; | ||
| use alloc::vec::Vec; | ||
| /// Constructs a special weight tensor used for unfolding. | ||
| /// | ||
| /// # Notes | ||
| /// | ||
| /// The idea behind using convolution for unfolding is to leverage the sliding window mechanism of | ||
| /// convolution. By creating a weight tensor with ones in a particular pattern, we are able to borrow | ||
| /// the convolution operation's mechanism as it moves across the input tensor, picking up the desired | ||
| /// values in the pattern of the unfolding operation. | ||
| pub(crate) fn create_unfolding_weight<B: Backend>( | ||
| in_channels: usize, | ||
| kernel_size: [usize; 2], | ||
| device: &B::Device, | ||
| ) -> FloatTensor<B> { | ||
| let shape = Shape::new([ | ||
| in_channels * kernel_size[0] * kernel_size[1], | ||
| in_channels, | ||
| kernel_size[0], | ||
| kernel_size[1], | ||
| ]); | ||
| let mut strides = [0; 4]; | ||
| let mut current = 1; | ||
| shape | ||
| .dims | ||
| .iter() | ||
| .enumerate() | ||
| .rev() | ||
| .for_each(|(index, val)| { | ||
| strides[index] = current; | ||
| current *= val; | ||
| }); | ||
| let num_elements = shape.num_elements(); | ||
| let mut weight: Vec<B::FloatElem> = vec![0.0.elem(); num_elements]; | ||
| for k in 0..in_channels { | ||
| for i in 0..kernel_size[0] { | ||
| for j in 0..kernel_size[1] { | ||
| let output_channel = k * kernel_size[0] * kernel_size[1] + i * kernel_size[1] + j; | ||
| let index = | ||
| output_channel * strides[0] + k * strides[1] + i * strides[2] + j * strides[3]; | ||
| weight[index] = 1.elem(); | ||
| } | ||
| } | ||
| } | ||
| B::float_from_data(TensorData::new(weight, shape), device) | ||
| } | ||
| /// Compute the unfold4d operation using the conv2d operations. | ||
| pub(crate) fn unfold4d_using_conv2d<B: Backend>( | ||
| x: FloatTensor<B>, | ||
| kernel_size: [usize; 2], | ||
| options: UnfoldOptions, | ||
| ) -> FloatTensor<B> { | ||
| let [_batch_size, in_channels, _in_height, _in_width] = x.shape().dims(); | ||
| let weight = create_unfolding_weight::<B>(in_channels, kernel_size, &B::float_device(&x)); | ||
| let unfolded = B::conv2d( | ||
| x, | ||
| weight, | ||
| None, | ||
| ConvOptions::new(options.stride, options.padding, options.dilation, 1), | ||
| ); | ||
| let [batch_size, channels_out, out_height, out_width] = unfolded.shape().dims(); | ||
| B::float_reshape( | ||
| unfolded, | ||
| Shape::new([batch_size, channels_out, out_height * out_width]), | ||
| ) | ||
| } | ||
/// Calculate the number of unfolding windows that can be extracted from a dimension of given size.
///
/// Returns `(dim_size - window_size) / step_size + 1` when at least one window
/// fits, and `0` otherwise.
pub fn calculate_unfold_windows(dim_size: usize, window_size: usize, step_size: usize) -> usize {
    assert!(step_size > 0);
    // No complete window fits at all.
    if dim_size < window_size {
        return 0;
    }
    // One window at offset 0, plus one per full step that still fits.
    (dim_size - window_size) / step_size + 1
}
| /// Calculate the output shape for an unfold operation. | ||
| /// | ||
| /// The operation yields a view with all complete windows of size `size` in dimension `dim`; | ||
| /// where windows are advanced by `step` at each index. | ||
| /// | ||
| /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `shape` - The input shape to unfold; of shape ``[pre=..., dim shape, post=...]`` | ||
| /// * `dim` - the dimension to unfold. | ||
| /// * `size` - the size of each unfolded window. | ||
| /// * `step` - the step between each window. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A shape with ``[pre=..., windows, post=..., size]``. | ||
| pub fn calculate_unfold_shape<S: Into<Shape>>( | ||
| shape: S, | ||
| dim: usize, | ||
| size: usize, | ||
| step: usize, | ||
| ) -> Shape { | ||
| let mut shape = shape.into(); | ||
| let d_shape = shape[dim]; | ||
| let windows = calculate_unfold_windows(d_shape, size, step); | ||
| shape[dim] = windows; | ||
| shape.push(size); | ||
| shape | ||
| } | ||
#[cfg(test)]
mod tests {
    use super::*;
    // Window counts for `dim_size` elements, `window_size`-wide windows
    // advanced by `step_size`.
    #[test]
    fn test_calculate_unfold_windows() {
        // Window larger than the dimension: no complete window fits.
        assert_eq!(calculate_unfold_windows(2, 5, 1), 0);
        assert_eq!(calculate_unfold_windows(2, 3, 1), 0);
        // Unit step: one window per starting offset that still fits.
        assert_eq!(calculate_unfold_windows(3, 3, 1), 1);
        assert_eq!(calculate_unfold_windows(4, 3, 1), 2);
        assert_eq!(calculate_unfold_windows(5, 3, 1), 3);
        // Step of 2: only every other offset can start a window.
        assert_eq!(calculate_unfold_windows(2, 3, 2), 0);
        assert_eq!(calculate_unfold_windows(3, 3, 2), 1);
        assert_eq!(calculate_unfold_windows(4, 3, 2), 1);
        assert_eq!(calculate_unfold_windows(5, 3, 2), 2);
    }
    // Unfolding [2, 6, 6] on dim 1 (size 3, step 2): dim 1 becomes 2 windows
    // and the window size 3 is appended as a trailing dimension.
    #[test]
    fn test_calculate_unfold_shape() {
        assert_eq!(
            calculate_unfold_shape([2, 6, 6], 1, 3, 2),
            Shape::new([2, 2, 6, 3])
        );
    }
}
| use crate::quantization::QuantScheme; | ||
| use alloc::vec::Vec; | ||
| use crate::{ | ||
| Device, Shape, TensorData, TensorMetadata, TensorPrimitive, | ||
| backend::{Backend, ExecutionError}, | ||
| quantization::{ | ||
| Calibration, QTensorPrimitive, QuantPropagation, QuantizationParametersPrimitive, | ||
| compute_q_params_primitive, compute_range_primitive, | ||
| }, | ||
| }; | ||
| use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor}; | ||
/// Automatically applies `dequantization -> float operation -> quantization`.
///
/// Used for tensor ops that should always return a quantized output.
///
/// Two arms are provided:
/// * binary: `dequant_op_quant!(ty T, float_op op, lhs, rhs)`
/// * unary: `dequant_op_quant!(ty T, float_op op, tensor)`
///
/// The output is re-quantized via `quantize_dynamic` using the (lhs) input's scheme.
#[macro_export]
macro_rules! dequant_op_quant {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: prioritize lhs scheme
        // (captured before `$t1` is consumed by dequantization).
        let scheme = $t1.scheme().clone();
        let t1_f = <$ty>::dequantize($t1);
        let t2_f = <$ty>::dequantize($t2);
        // `$float_op` is typically a closure literal; calling it inline is intentional.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);
        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        let scheme = $tensor.scheme().clone();
        let tensor_f = <$ty>::dequantize($tensor);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);
        <$ty>::quantize_dynamic(out_f, &scheme)
    }};
}
/// Automatically applies `dequantization -> float operation [-> quantization]`.
///
/// The output quantization step is optional.
/// It is only performed when the input quantization scheme is propagated
/// (`QuantPropagation::Propagate`); otherwise the float result is returned
/// as-is wrapped in `TensorPrimitive::Float`.
#[macro_export]
macro_rules! dequant_op_flow {
    // Binary tensor float op w/ lhs & rhs
    (
        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr
    ) => {{
        // Heuristic: prioritize lhs scheme
        let scheme = $t1.scheme().clone();
        // Capture propagation before `$t1` is consumed by dequantization.
        let propagation = $t1.propagation();
        let t1_f = <$ty>::dequantize($t1);
        let t2_f = <$ty>::dequantize($t2);
        // `$float_op` is typically a closure literal; calling it inline is intentional.
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(t1_f, t2_f);
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
    // Unary tensor float op
    (
        ty $ty:ty, float_op $float_op:expr, $tensor:expr
    ) => {{
        let scheme = $tensor.scheme().clone();
        let propagation = $tensor.propagation();
        let tensor_f = <$ty>::dequantize($tensor);
        #[allow(clippy::redundant_closure_call)]
        let out_f = $float_op(tensor_f);
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),
        }
    }};
}
| /// Operations on quantized tensors. | ||
| /// | ||
| /// # Return Type Semantics | ||
| /// | ||
| /// The return type of each operation indicates how quantization is handled: | ||
| /// | ||
| /// ## [`QuantizedTensor<B>`] | ||
| /// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized | ||
| /// representation. Implementations should avoid dequantizing when possible to maintain performance. | ||
| /// For example, shape or layout changes such as expand or transpose preserve quantization. | ||
| /// | ||
| /// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is | ||
| /// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering | ||
| /// the quantization parameters to match the new layout.* | ||
| /// | ||
| /// | ||
| /// ## [`TensorPrimitive<B>`] | ||
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the
/// propagation strategy specified in the quantization scheme. The output should remain quantized
/// ([`TensorPrimitive::QFloat`]) when propagation is enabled, or be returned in floating-point
/// form ([`TensorPrimitive::Float`]) otherwise.
| /// | ||
| /// This distinction allows for fine-grained control over mixed-precision flows while still operating | ||
| /// through a unified API. | ||
| pub trait QTensorOps<B: Backend> { | ||
| /// Creates a new tensor from the data structure. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `data` - The data structure. | ||
| /// * `device` - The device to create the tensor on. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the given data. | ||
| fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>; | ||
| /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. | ||
| fn quantize( | ||
| tensor: FloatTensor<B>, | ||
| scheme: &QuantScheme, | ||
| qparams: QuantizationParametersPrimitive<B>, | ||
| ) -> QuantizedTensor<B>; | ||
    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.
    ///
    /// Unlike [`Self::quantize`], the quantization parameters are not supplied by the
    /// caller: they are derived on the fly from the tensor's own value range.
    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {
        // Dynamically compute min/max tensor range and qparams before quantizing
        // (min/max calibration over the whole tensor).
        let (min, max) = compute_range_primitive::<B>(scheme, tensor.clone(), &Calibration::MinMax);
        let qparams = compute_q_params_primitive(scheme, min, max);
        Self::quantize(tensor, scheme, qparams)
    }
| /// Convert the tensor back to a higher precision data type. | ||
| fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>; | ||
| /// Gets the device of the tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The device of the tensor. | ||
| fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>; | ||
| /// Moves the tensor to the given device. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// * `device` - The device to move the tensor to. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor on the given device. | ||
| fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>; | ||
| /// Reshapes a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to reshape. | ||
| /// * `shape` - The new shape of the tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the new shape. | ||
| fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>; | ||
| /// Converts the tensor to a data structure. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The data structure with the tensor's data. | ||
| fn q_into_data( | ||
| tensor: QuantizedTensor<B>, | ||
| ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send; | ||
| /// Detaches a tensor from the computation graph. | ||
| fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> { | ||
| // Should only be overridden by autodiff backends. | ||
| tensor | ||
| } | ||
| /// Sets the `require_grad` flag of a tensor. | ||
| fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> { | ||
| // Should only be overridden by autodiff backends. | ||
| tensor | ||
| } | ||
| /// Returns the `require_grad` flag of a tensor. | ||
| fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool { | ||
| // Should only be overridden by autodiff backends. | ||
| false | ||
| } | ||
| /// Broadcasts the `tensor` to the given `shape`. | ||
| fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>; | ||
| /// Transposes a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to transpose. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The transposed tensor. | ||
| fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> { | ||
| let ndims = tensor.shape().num_dims(); | ||
| Self::q_swap_dims(tensor, ndims - 2, ndims - 1) | ||
| } | ||
| /// Swaps two dimensions of a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to swap the dimensions of. | ||
| /// * `dim1` - The first dimension to swap. | ||
| /// * `dim2` - The second dimension to swap. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the dimensions swapped. | ||
| fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>; | ||
| /// Permutes the dimensions of a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to permute the dimensions of. | ||
| /// * `axes` - The new order of the dimensions. | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the dimensions permuted. | ||
| fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>; | ||
| /// Reverse the order of elements in a tensor along the given axes. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to reverse. | ||
| /// * `axes` - The axes to reverse. | ||
| /// | ||
| /// The tensor with the elements reversed. | ||
| fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>; | ||
/// Select tensor elements along the given dimension corresponding to the given indices.
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to select from. | ||
| /// * `dim` - The dimension to select from. | ||
| /// * `indices` - The indices to select. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The selected elements. | ||
| fn q_select( | ||
| tensor: QuantizedTensor<B>, | ||
| dim: usize, | ||
| indices: IntTensor<B>, | ||
| ) -> QuantizedTensor<B>; | ||
| /// Select tensor elements corresponding to the given slices. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to select from. | ||
| /// * `slices` - The slices specifying ranges and steps for each dimension. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The selected elements in a new tensor. | ||
| fn q_slice(tensor: QuantizedTensor<B>, slices: &[crate::Slice]) -> QuantizedTensor<B>; | ||
| /// Gather elements from a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `dim` - The dimension to gather from. | ||
| /// * `tensor` - The tensor to gather from. | ||
| /// * `indices` - The indices to gather. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The gathered elements. | ||
| fn q_gather( | ||
| dim: usize, | ||
| tensor: QuantizedTensor<B>, | ||
| indices: IntTensor<B>, | ||
| ) -> QuantizedTensor<B> { | ||
| // Default implementation. Backends can gather on the quantized values when supported. | ||
| dequant_op_quant!( | ||
| ty Self, | ||
| float_op |tensor| B::float_gather(dim, tensor, indices), | ||
| tensor | ||
| ) | ||
| } | ||
| /// Repeat the tensor along the given dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// * `dim` - The dimension to repeat. | ||
| /// * `times` - The number of times to repeat the dimension. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the given dimension repeated. | ||
| fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> { | ||
| dequant_op_quant!( | ||
| ty Self, | ||
| float_op |tensor| B::float_repeat_dim(tensor, dim, times), | ||
| tensor | ||
| ) | ||
| } | ||
| /// Adds two tensors together. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left hand side tensor. | ||
| /// * `rhs` - The right hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of adding the two tensors together. | ||
| fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> { | ||
| dequant_op_flow!( | ||
| ty Self, | ||
| float_op |lhs, rhs| B::float_add(lhs, rhs), | ||
| lhs, | ||
| rhs | ||
| ) | ||
| } | ||
| /// Adds a scalar to a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left hand side tensor. | ||
| /// * `rhs` - The right hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of adding the scalar to the tensor. | ||
| fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> { | ||
| dequant_op_flow!( | ||
| ty Self, | ||
| float_op |tensor| B::float_add_scalar(tensor, rhs), | ||
| lhs | ||
| ) | ||
| } | ||
| /// Clamps a tensor under a minimum value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to clamp. | ||
| /// * `min` - The minimum value. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The clamped tensor. | ||
| fn q_clamp_min(tensor: QuantizedTensor<B>, min: FloatElem<B>) -> TensorPrimitive<B> { | ||
| dequant_op_flow!( | ||
| ty Self, | ||
| float_op |tensor| B::float_clamp_min(tensor, min), | ||
| tensor | ||
| ) | ||
| } | ||
| /// Clamps a tensor over a maximum value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to clamp. | ||
| /// * `max` - The maximum value. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The clamped tensor. | ||
| fn q_clamp_max(tensor: QuantizedTensor<B>, max: FloatElem<B>) -> TensorPrimitive<B> { | ||
| dequant_op_flow!( | ||
| ty Self, | ||
| float_op |tensor| B::float_clamp_max(tensor, max), | ||
| tensor | ||
| ) | ||
| } | ||
| /// Clamps a tensor between a minimum and maximum value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to clamp. | ||
| /// * `min` - The minimum value. | ||
| /// * `max` - The maximum value. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The clamped tensor. | ||
| fn q_clamp( | ||
| tensor: QuantizedTensor<B>, | ||
| min: FloatElem<B>, | ||
| max: FloatElem<B>, | ||
| ) -> TensorPrimitive<B> { | ||
| dequant_op_flow!( | ||
| ty Self, | ||
| float_op |tensor| B::float_clamp(tensor, min, max), | ||
| tensor | ||
| ) | ||
| } | ||
| /// Subtracts two tensors. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left hand side tensor. | ||
| /// * `rhs` - The right hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of subtracting the two tensors. | ||
| fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> { | ||
| dequant_op_flow!( | ||
| ty Self, | ||
| float_op |lhs, rhs| B::float_sub(lhs, rhs), | ||
| lhs, | ||
| rhs | ||
| ) | ||
| } | ||
| /// Subtracts a scalar from a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left hand side tensor. | ||
| /// * `rhs` - The right hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of subtracting the scalar from the tensor. | ||
| fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> { | ||
| dequant_op_flow!( | ||
| ty Self, | ||
| float_op |tensor| B::float_sub_scalar(tensor, rhs), | ||
| lhs | ||
| ) | ||
| } | ||
| /// Multiplies two tensors together element-wise. | ||
| fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> { | ||
| dequant_op_flow!( | ||
| ty Self, | ||
| float_op |lhs, rhs| B::float_mul(lhs, rhs), | ||
| lhs, | ||
| rhs | ||
| ) | ||
| } | ||
| /// Multiplies a tensor by a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left hand side tensor. | ||
| /// * `rhs` - The right hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of multiplying the tensor by the scalar. | ||
| fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> { | ||
| dequant_op_flow!( | ||
| ty Self, | ||
| float_op |tensor| B::float_mul_scalar(tensor, rhs), | ||
| lhs | ||
| ) | ||
| } | ||
| /// Divides two tensors element-wise. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left hand side tensor. | ||
| /// * `rhs` - The right hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of dividing the two tensors. | ||
| fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> { | ||
| dequant_op_flow!( | ||
| ty Self, | ||
| float_op |lhs, rhs| B::float_div(lhs, rhs), | ||
| lhs, | ||
| rhs | ||
| ) | ||
| } | ||
| /// Divides a tensor by a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left hand side tensor. | ||
| /// * `rhs` - The right hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of dividing the tensor by the scalar. | ||
| fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: FloatElem<B>) -> TensorPrimitive<B> { | ||
| dequant_op_flow!( | ||
| ty Self, | ||
| float_op |tensor| B::float_div_scalar(tensor, rhs), | ||
| lhs | ||
| ) | ||
| } | ||
| /// Multiplies two tensors together using matrix multiplication. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left hand side tensor. | ||
| /// * `rhs` - The right hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of multiplying the two tensors together using matrix multiplication. | ||
| fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> { | ||
| let mut propagation = QuantPropagation::Inhibit; | ||
| let mut scheme = QuantScheme::default(); | ||
| let lhs = match lhs { | ||
| TensorPrimitive::Float(lhs) => lhs, | ||
| TensorPrimitive::QFloat(lhs) => { | ||
| propagation = lhs.propagation(); | ||
| scheme = *lhs.scheme(); | ||
| Self::dequantize(lhs) | ||
| } | ||
| }; | ||
| let rhs = match rhs { | ||
| TensorPrimitive::Float(rhs) => rhs, | ||
| TensorPrimitive::QFloat(rhs) => { | ||
| propagation = rhs.propagation(); | ||
| scheme = *rhs.scheme(); | ||
| Self::dequantize(rhs) | ||
| } | ||
| }; | ||
| let out_f = B::float_matmul(lhs, rhs); | ||
| match propagation { | ||
| QuantPropagation::Propagate => { | ||
| TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme)) | ||
| } | ||
| QuantPropagation::Inhibit => TensorPrimitive::Float(out_f), | ||
| } | ||
| } | ||
    /// Negates a tensor element-wise.
    fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_neg`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_neg(tensor),
            tensor
        )
    }
    /// Calculates the reciprocals element-wise.
    fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_recip`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_recip(tensor),
            tensor
        )
    }
    /// Sum of all elements in a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    ///
    /// # Returns
    ///
    /// A scalar tensor with the sum of all elements in `tensor`.
    fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_sum`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_sum(tensor),
            tensor
        )
    }
    /// Sum of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    /// * `dim` - The dimension along which to sum.
    ///
    /// # Returns
    ///
    /// A tensor with the sum of all elements in `tensor` along `dim`.
    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_sum_dim`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_sum_dim(tensor, dim),
            tensor
        )
    }
    /// Product of all elements in a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose elements to multiply together.
    ///
    /// # Returns
    ///
    /// A scalar tensor with the product of all elements in `tensor`.
    fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_prod`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_prod(tensor),
            tensor
        )
    }
    /// Product of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose elements to multiply together.
    /// * `dim` - The dimension along which to take the product.
    ///
    /// # Returns
    ///
    /// A tensor with the product of all elements in `tensor` along `dim`.
    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_prod_dim`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_prod_dim(tensor, dim),
            tensor
        )
    }
    /// Mean of all elements in a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to average.
    ///
    /// # Returns
    ///
    /// A scalar tensor with the mean of all elements in `tensor`.
    fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_mean`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_mean(tensor),
            tensor
        )
    }
    /// Mean of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to average.
    /// * `dim` - The dimension along which to take the mean.
    ///
    /// # Returns
    ///
    /// A tensor with the mean of all elements in `tensor` along `dim`.
    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_mean_dim`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_mean_dim(tensor, dim),
            tensor
        )
    }
    /// Computes the cumulative sum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative sum of.
    /// * `dim` - The dimension along which to compute the cumulative sum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative sum
    /// of all elements up to and including that position along the dimension.
    fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_cumsum`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_cumsum(tensor, dim),
            tensor
        )
    }
    /// Computes the cumulative product of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative product of.
    /// * `dim` - The dimension along which to compute the cumulative product.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative product
    /// of all elements up to and including that position along the dimension.
    fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_cumprod`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_cumprod(tensor, dim),
            tensor
        )
    }
    /// Computes the cumulative minimum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative minimum of.
    /// * `dim` - The dimension along which to compute the cumulative minimum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the minimum
    /// of all elements up to and including that position along the dimension.
    fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_cummin`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_cummin(tensor, dim),
            tensor
        )
    }
    /// Computes the cumulative maximum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative maximum of.
    /// * `dim` - The dimension along which to compute the cumulative maximum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the maximum
    /// of all elements up to and including that position along the dimension.
    fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_cummax`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_cummax(tensor, dim),
            tensor
        )
    }
    /// Returns a new tensor with exponential values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with exponential values.
    fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_exp`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_exp(tensor),
            tensor
        )
    }
    /// Returns a new tensor with natural logarithm values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with natural logarithm values.
    fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_log`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_log(tensor),
            tensor
        )
    }
    /// Returns a new tensor with logarithm values of (1 + Xi).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
    fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_log1p`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_log1p(tensor),
            tensor
        )
    }
    /// Element-wise power with another tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the power of the elements of `rhs`.
    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize both operands, apply `B::float_powf`;
        // output form is handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |lhs, rhs| B::float_powf(lhs, rhs),
            lhs,
            rhs
        )
    }
    /// Element-wise power with an IntTensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side int tensor (exponents).
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the corresponding elements of `rhs`.
    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize `lhs`, apply `B::float_powi` with the int
        // exponents; output form is handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_powi(tensor, rhs),
            lhs
        )
    }
    /// Element-wise power with an int scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar exponent.
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the value of `rhs`.
    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: IntElem<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_powi_scalar`; output form
        // is handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_powi_scalar(tensor, rhs),
            lhs
        )
    }
    /// Element-wise power with a float scalar.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    /// * `value` - The exponent.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: f32) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_powf_scalar`; output form
        // is handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_powf_scalar(tensor, value),
            tensor
        )
    }
    /// Returns a new tensor with square root values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the square root of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with square root values.
    fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_sqrt`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_sqrt(tensor),
            tensor
        )
    }
    /// Returns a new tensor with absolute values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Unlike the `dequant_op_flow!` ops, this always returns a quantized
        // tensor (`dequant_op_quant!`): dequantize, `B::float_abs`, requantize.
        dequant_op_quant!(
            ty Self,
            float_op |tensor| B::float_abs(tensor),
            tensor
        )
    }
    /// Returns a new tensor with cosine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with cosine values.
    fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_cos`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_cos(tensor),
            tensor
        )
    }
    /// Returns a new tensor with sine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with sine values.
    fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_sin`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_sin(tensor),
            tensor
        )
    }
    /// Returns a new tensor with tangent values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with tangent values.
    fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_tan`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_tan(tensor),
            tensor
        )
    }
    /// Returns a new tensor with hyperbolic cosine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
    fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_cosh`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_cosh(tensor),
            tensor
        )
    }
    /// Returns a new tensor with hyperbolic sine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
    fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_sinh`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_sinh(tensor),
            tensor
        )
    }
    /// Returns a new tensor with hyperbolic tangent values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
    fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_tanh`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_tanh(tensor),
            tensor
        )
    }
    /// Returns a new tensor with the error function values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the error function of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with error function values.
    fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {
        // Default flow: dequantize, apply `B::float_erf`; output form is
        // handled by `dequant_op_flow!`.
        dequant_op_flow!(
            ty Self,
            float_op |tensor| B::float_erf(tensor),
            tensor
        )
    }
| /// Concatenates tensors along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensors` - The tensors to concatenate. | ||
| /// * `dim` - The dimension along which to concatenate. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the concatenated tensors along `dim`. | ||
| fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> { | ||
| // Heuristic: prioritize first tensor scheme | ||
| let scheme = *tensors.first().unwrap().scheme(); | ||
| let tensor_f = tensors | ||
| .into_iter() | ||
| .map(|tensor| Self::dequantize(tensor)) | ||
| .collect(); | ||
| let out_f = B::float_cat(tensor_f, dim); | ||
| Self::quantize_dynamic(out_f, &scheme) | ||
| } | ||
| /// Gets the indices of the maximum elements of a tensor along an axis. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum elements of. | ||
| /// * `dim` - The dimension along which to get the maximum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the indices of the maximum elements of `tensor` along `dim`. | ||
| fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> { | ||
| // Default implementation. Backends can sort on the int values since qparams remain the same. | ||
| let tensor_f = Self::dequantize(tensor); | ||
| B::float_argmax(tensor_f, dim) | ||
| } | ||
| /// Gets the indices of the minimum elements of a tensor along an axis. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the minimum elements of. | ||
| /// * `dim` - The dimension along which to get the minimum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the indices of the minimum elements of `tensor` along `dim`. | ||
| fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> { | ||
| // Default implementation. Backends can sort on the int values since qparams remain the same. | ||
| let tensor_f = Self::dequantize(tensor); | ||
| B::float_argmin(tensor_f, dim) | ||
| } | ||
| /// Gets the maximum element of a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum elements of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the maximum element of `tensor`. | ||
| fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> { | ||
| let shape = tensor.shape(); | ||
| let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()])); | ||
| B::q_max_dim(tensor, 0) | ||
| } | ||
| /// Gets the maximum elements of a tensor along an axis. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum elements of. | ||
| /// * `dim` - The dimension along which to get the maximum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the maximum elements of `tensor` along `dim`. | ||
| fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> { | ||
| let index = B::q_argmax(tensor.clone(), dim); | ||
| B::q_gather(dim, tensor, index) | ||
| } | ||
| /// Gets the maximum elements of a tensor along an axis and their indices. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum elements of. | ||
| /// * `dim` - The dimension along which to get the maximum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tuple with the maximum elements of `tensor` along `dim` and their indices. | ||
| fn q_max_dim_with_indices( | ||
| tensor: QuantizedTensor<B>, | ||
| dim: usize, | ||
| ) -> (QuantizedTensor<B>, IntTensor<B>) { | ||
| let index = B::q_argmax(tensor.clone(), dim); | ||
| let values = B::q_gather(dim, tensor, index.clone()); | ||
| (values, index) | ||
| } | ||
| /// Gets the minimum element of a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the minimum elements of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the minimum element of `tensor`. | ||
| fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> { | ||
| let shape = tensor.shape(); | ||
| let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()])); | ||
| B::q_min_dim(tensor, 0) | ||
| } | ||
| /// Gets the minimum elements of a tensor along an axis. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the minimum elements of. | ||
| /// * `dim` - The dimension along which to get the minimum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the minimum elements of `tensor` along `dim`. | ||
| fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> { | ||
| let index = B::q_argmin(tensor.clone(), dim); | ||
| B::q_gather(dim, tensor, index) | ||
| } | ||
| /// Gets the minimum elements of a tensor along an axis and their indices. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the minimum elements of. | ||
| /// * `dim` - The dimension along which to get the minimum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tuple with the minimum elements of `tensor` along `dim` and their indices. | ||
| fn q_min_dim_with_indices( | ||
| tensor: QuantizedTensor<B>, | ||
| dim: usize, | ||
| ) -> (QuantizedTensor<B>, IntTensor<B>) { | ||
| let index = B::q_argmin(tensor.clone(), dim); | ||
| let values = B::q_gather(dim, tensor, index.clone()); | ||
| (values, index) | ||
| } | ||
    /// Gets the maximum absolute value of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reduce.
    ///
    /// # Returns
    ///
    /// A tensor with the maximum absolute value of `tensor`.
    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {
        // Flatten to 1-D, then delegate to the per-dimension reduction.
        let shape = tensor.shape();
        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));
        B::q_max_abs_dim(tensor, 0)
    }
| /// Gets the maximum elements of a tensor along an axis. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum elements of. | ||
| /// * `dim` - The dimension along which to get the maximum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the maximum elements of `tensor` along `dim`. | ||
| fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> { | ||
| let index = B::q_argmax(B::q_abs(tensor.clone()), dim); | ||
| B::q_gather(dim, tensor, index) | ||
| } | ||
| /// Tests if any element in the `tensor` evaluates to True. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise. | ||
| fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> { | ||
| let tensor_f = Self::dequantize(tensor); | ||
| B::float_any(tensor_f) | ||
| } | ||
| /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// * `dim` - The axis along which to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis | ||
| /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the | ||
| /// input evaluates to True, False otherwise. | ||
| fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> { | ||
| let tensor_f = Self::dequantize(tensor); | ||
| B::float_any_dim(tensor_f, dim) | ||
| } | ||
| /// Tests if all elements in the `tensor` evaluate to True. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor | ||
| /// evaluate to True, False otherwise. | ||
| fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> { | ||
| let tensor_f = Self::dequantize(tensor); | ||
| B::float_all(tensor_f) | ||
| } | ||
| /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// * `dim` - The axis along which to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis | ||
| /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input | ||
| /// evaluates to True, False otherwise. | ||
| fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> { | ||
| let tensor_f = Self::dequantize(tensor); | ||
| B::float_all_dim(tensor_f, dim) | ||
| } | ||
    /// Sort the elements of the input `tensor` by value in along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {
        // Default implementation. Backends can sort on the int values since qparams remain the same.
        // `dequant_op_quant!`: dequantize, sort in float, then requantize.
        dequant_op_quant!(
            ty Self,
            float_op |tensor| B::float_sort(tensor, dim, descending),
            tensor
        )
    }
| /// Sort the elements of the input `tensor` by value in along a given dimension. | ||
| /// | ||
| /// This sort is unstable (i.e., may reorder equal elements). | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The input tensor. | ||
| /// * `dim` - The axis along which to sort. | ||
| /// * `descending` - The sorting order. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as the input tensor and corresponding indices, where | ||
| /// the elements are sorted by value and the indices map back to the original input tensor. | ||
| fn q_sort_with_indices( | ||
| tensor: QuantizedTensor<B>, | ||
| dim: usize, | ||
| descending: bool, | ||
| ) -> (QuantizedTensor<B>, IntTensor<B>) { | ||
| // Default implementation. Backends can sort on the int values since qparams remain the same. | ||
| let scheme = *tensor.scheme(); | ||
| let tensor_f = Self::dequantize(tensor); | ||
| let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending); | ||
| (Self::quantize_dynamic(out_f, &scheme), indices) | ||
| } | ||
| /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension. | ||
| /// | ||
| /// This sort is unstable (i.e., may reorder equal elements). | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The input tensor. | ||
| /// * `dim` - The axis along which to sort. | ||
| /// * `descending` - The sorting order. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as the input tensor the indices map back to the original input tensor. | ||
| fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> { | ||
| // Default implementation. Backends can sort on the int values since qparams remain the same. | ||
| let tensor_f = Self::dequantize(tensor); | ||
| B::float_argsort(tensor_f, dim, descending) | ||
| } | ||
| } |
| use super::cat::cat_with_slice_assign; | ||
| use super::grid_sample::float_grid_sample_2d_ref; | ||
| use super::repeat_dim::repeat_with_slice_assign; | ||
| use super::{BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor}; | ||
| use crate::backend::ExecutionError; | ||
| use crate::ops::GridSampleOptions; | ||
| use crate::{Distribution, ElementConversion, Float, TensorData, backend::Backend, tensor::Shape}; | ||
| use crate::{FloatDType, TensorMetadata, TensorPrimitive}; | ||
| use crate::{argsort, sort, sort_with_indices}; | ||
| use alloc::vec::Vec; | ||
| /// Operations on float tensors. | ||
| pub trait FloatTensorOps<B: Backend> { | ||
    /// Creates a new tensor from the data structure.
    ///
    /// Required method: implemented by each backend.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given data.
    fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
    /// Creates a new tensor with random values.
    ///
    /// Required method: implemented by each backend.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    fn float_random(shape: Shape, distribution: Distribution, device: &Device<B>)
    -> FloatTensor<B>;
| /// Creates a new tensor with zeros. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `shape` - The shape of the tensor. | ||
| /// * `device` - The device to create the tensor on. | ||
| /// * `dtype` - The target data type. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the given shape and zeros. | ||
| fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> { | ||
| Self::float_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device) | ||
| } | ||
| /// Creates a new tensor with ones. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `shape` - The shape of the tensor. | ||
| /// * `device` - The device to create the tensor on. | ||
| /// * `dtype` - The target data type. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the given shape and ones. | ||
| fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> { | ||
| Self::float_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device) | ||
| } | ||
| /// Creates a tensor filled with given value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `shape` - The shape of the tensor. | ||
| /// * `fill_value` - The value with which to fill the tensor. | ||
| /// * `device` - The device to create the tensor on. | ||
| /// * `dtype` - The target data type. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor filled with given value | ||
| fn float_full( | ||
| shape: Shape, | ||
| fill_value: FloatElem<B>, | ||
| device: &Device<B>, | ||
| dtype: FloatDType, | ||
| ) -> FloatTensor<B> { | ||
| Self::float_from_data( | ||
| TensorData::full_dtype(shape, fill_value, dtype.into()), | ||
| device, | ||
| ) | ||
| } | ||
    /// Converts the tensor to a data structure.
    ///
    /// Required method: implemented by each backend. The read is asynchronous
    /// and fallible.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A future resolving to the data structure with the tensor's data, or an
    /// [`ExecutionError`] if the read fails.
    fn float_into_data(
        tensor: FloatTensor<B>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
    /// Gets the device of the tensor.
    ///
    /// Required method: implemented by each backend.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
    /// Moves the tensor to the given device.
    ///
    /// Required method: implemented by each backend.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
    /// Converts float tensor to int tensor.
    ///
    /// Required method: implemented by each backend.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The int tensor with the same data as the float tensor.
    fn float_into_int(tensor: FloatTensor<B>) -> IntTensor<B>;
    /// Creates an empty tensor with the given shape.
    ///
    /// Required method: implemented by each backend. The tensor's contents
    /// are uninitialized/unspecified.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// The empty tensor with the given shape.
    fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;
| /// Repeat the tensor along the given dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor. | ||
| /// * `dim` - The dimension to repeat. | ||
| /// * `times` - The number of times to repeat the dimension. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the given dimension repeated. | ||
| fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> { | ||
| repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor() | ||
| } | ||
    /// Adds two tensors together.
    ///
    /// Required method: implemented by each backend.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of adding the two tensors together.
    fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
    /// Adds a scalar to a tensor.
    ///
    /// Required method: implemented by each backend.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of adding the scalar to the tensor.
    fn float_add_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
| /// Clamps a tensor under a minimum value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to clamp. | ||
| /// * `min` - The minimum value. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The clamped tensor. | ||
| fn float_clamp_min(tensor: FloatTensor<B>, min: FloatElem<B>) -> FloatTensor<B> { | ||
| // Default implementation: mask the elements strictly below `min`, | ||
| // then overwrite exactly those positions with `min`. | ||
| let mask = Self::float_lower_elem(tensor.clone(), min); | ||
| B::float_mask_fill(tensor, mask, min) | ||
| } | ||
| /// Clamps a tensor over a maximum value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to clamp. | ||
| /// * `max` - The maximum value. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The clamped tensor. | ||
| fn float_clamp_max(tensor: FloatTensor<B>, max: FloatElem<B>) -> FloatTensor<B> { | ||
| // Default implementation: mask the elements strictly above `max`, | ||
| // then overwrite exactly those positions with `max`. | ||
| let mask = Self::float_greater_elem(tensor.clone(), max); | ||
| B::float_mask_fill(tensor, mask, max) | ||
| } | ||
| /// Clamps a tensor between a minimum and maximum value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to clamp. | ||
| /// * `min` - The minimum value. | ||
| /// * `max` - The maximum value. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The clamped tensor. | ||
| fn float_clamp(tensor: FloatTensor<B>, min: FloatElem<B>, max: FloatElem<B>) -> FloatTensor<B> { | ||
| // Default implementation: clamp to `max` first, then to `min`. Because the | ||
| // min-clamp runs last, `min` wins for any element where `min > max` (matching | ||
| // the usual clamp convention, e.g. PyTorch's `torch.clamp`). | ||
| Self::float_clamp_min(Self::float_clamp_max(tensor, max), min) | ||
| } | ||
| /// Subtracts two tensors. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of subtracting the two tensors. | ||
| fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Subtracts a scalar from a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of subtracting the scalar from the tensor. | ||
| fn float_sub_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>; | ||
| /// Multiplies two tensors together element-wise. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of multiplying the two tensors together element-wise. | ||
| fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Multiplies a tensor by a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of multiplying the tensor by the scalar. | ||
| fn float_mul_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>; | ||
| /// Divides two tensors element-wise. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of dividing the two tensors. | ||
| fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Divides a tensor by a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of dividing the tensor by the scalar. | ||
| fn float_div_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>; | ||
| /// Computes the remainder of division between two tensors element-wise. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The element-wise remainder when dividing `lhs` by `rhs`. | ||
| fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Computes the modulus of a tensor given a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of applying the modulus of the scalar to the tensor. | ||
| fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>; | ||
| /// Multiplies two tensors together using matrix multiplication. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The result of multiplying the two tensors together using matrix multiplication. | ||
| fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Computes the cross product of two tensors along a given dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// * `dim` - The dimension to compute the cross product along. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The cross product of the two tensors. | ||
| fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>; | ||
| /// Negates a tensor element-wise. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to negate. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with each element negated. | ||
| fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| // Default implementation: negation as multiplication by -1. | ||
| Self::float_mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>()) | ||
| } | ||
| /// Calculates the reciprocals element-wise | ||
| fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Transposes a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to transpose. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The transposed tensor. | ||
| fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| let ndims = tensor.shape().num_dims(); | ||
| Self::float_swap_dims(tensor, ndims - 2, ndims - 1) | ||
| } | ||
| /// Swaps two dimensions of a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to swap the dimensions of. | ||
| /// * `dim1` - The first dimension to swap. | ||
| /// * `dim2` - The second dimension to swap. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the dimensions swapped. | ||
| fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>; | ||
| /// Permutes the dimensions of a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to permute the dimensions of. | ||
| /// * `axes` - The new order of the dimensions. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the dimensions permuted. | ||
| fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>; | ||
| /// Reverse the order of elements in a tensor along the given axes. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to reverse. | ||
| /// * `axes` - The axes to reverse. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the elements reversed. | ||
| fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>; | ||
| /// Reshapes a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to reshape. | ||
| /// * `shape` - The new shape of the tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the new shape. | ||
| fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>; | ||
| /// Gather elements from a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `dim` - The dimension to gather from. | ||
| /// * `tensor` - The tensor to gather from. | ||
| /// * `indices` - The indices to gather. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The gathered elements. | ||
| fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>; | ||
| /// Scatter elements into a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `dim` - The dimension to scatter into. | ||
| /// * `tensor` - The tensor to scatter into. | ||
| /// * `indices` - The indices to scatter into. | ||
| /// * `value` - The value to scatter. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the scattered elements. | ||
| fn float_scatter( | ||
| dim: usize, | ||
| tensor: FloatTensor<B>, | ||
| indices: IntTensor<B>, | ||
| value: FloatTensor<B>, | ||
| ) -> FloatTensor<B>; | ||
| /// Select tensor elements along the given dimension corresponding for the given indices. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to select from. | ||
| /// * `dim` - The dimension to select from. | ||
| /// * `indices` - The indices to select. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The selected elements. | ||
| fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>; | ||
| /// Assign the selected elements along the given dimension corresponding for the given indices | ||
| /// to the given value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to select from. | ||
| /// * `dim` - The dimension to select from. | ||
| /// * `indices` - The indices to select. | ||
| /// * `value` - The value to assign. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the selected elements assigned to the given value. | ||
| fn float_select_assign( | ||
| tensor: FloatTensor<B>, | ||
| dim: usize, | ||
| indices: IntTensor<B>, | ||
| value: FloatTensor<B>, | ||
| ) -> FloatTensor<B>; | ||
| /// Select tensor elements corresponding to the given slices. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to select from. | ||
| /// * `slices` - The slices specifying ranges and steps for each dimension. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The selected elements in a new tensor. | ||
| /// | ||
| /// # Note | ||
| /// | ||
| /// Empty slices (where start >= end) are handled at the high-level tensor API and will not | ||
| /// be passed to this method. Backend implementations do not need to handle empty slices. | ||
| fn float_slice(tensor: FloatTensor<B>, slices: &[crate::Slice]) -> FloatTensor<B>; | ||
| /// Assign the selected elements corresponding to the given slices to the given value. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to select from. | ||
| /// * `ranges` - The ranges to select. | ||
| /// * `value` - The value to assign. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the selected elements assigned to the given value. | ||
| fn float_slice_assign( | ||
| tensor: FloatTensor<B>, | ||
| slices: &[crate::Slice], | ||
| value: FloatTensor<B>, | ||
| ) -> FloatTensor<B>; | ||
| /// Update the given tensor with the value tensor where the mask is true. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to select from. | ||
| /// * `mask` - The boolean mask to select with. | ||
| /// * `value` - The value to assign to the selected elements from the value tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the selected elements assigned to the given value. | ||
| fn float_mask_where( | ||
| tensor: FloatTensor<B>, | ||
| mask: BoolTensor<B>, | ||
| value: FloatTensor<B>, | ||
| ) -> FloatTensor<B>; | ||
| /// Update the given tensor with the value where the mask is true. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to select from. | ||
| /// * `mask` - The boolean mask to select with. | ||
| /// * `value` - The value to assign to the selected elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The tensor with the selected elements assigned to the given value. | ||
| fn float_mask_fill( | ||
| tensor: FloatTensor<B>, | ||
| mask: BoolTensor<B>, | ||
| value: FloatElem<B>, | ||
| ) -> FloatTensor<B>; | ||
| /// Equal comparison of two tensors. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>; | ||
| /// Element-wise non-equality comparison. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_not_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B> { | ||
| let equal_tensor = B::float_equal(lhs, rhs); | ||
| B::bool_not(equal_tensor) | ||
| } | ||
| /// Equal comparison of a tensor and a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>; | ||
| /// Element-wise non-equality comparison with a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_not_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B> { | ||
| let equal_tensor = B::float_equal_elem(lhs, rhs); | ||
| B::bool_not(equal_tensor) | ||
| } | ||
| /// Greater than comparison of two tensors. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_greater(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>; | ||
| /// Greater than comparison of a tensor and a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_greater_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>; | ||
| /// Greater than or equal comparison of two tensors. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_greater_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>; | ||
| /// Greater than or equal comparison of a tensor and a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_greater_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>; | ||
| /// Less than comparison of two tensors. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>; | ||
| /// Less than comparison of a tensor and a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_lower_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>; | ||
| /// Less than or equal comparison of two tensors. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_lower_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>; | ||
| /// Less than or equal comparison of a tensor and a scalar. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with the result of the comparison. | ||
| fn float_lower_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>; | ||
| /// Detaches a tensor from the computation graph. | ||
| fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| // Should only be overridden by autodiff backends. | ||
| tensor | ||
| } | ||
| /// Sets the `require_grad` flag of a tensor. | ||
| fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> { | ||
| // Should only be overridden by autodiff backends. | ||
| tensor | ||
| } | ||
| /// Returns the `require_grad` flag of a tensor. | ||
| fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool { | ||
| // Should only be overridden by autodiff backends. | ||
| false | ||
| } | ||
| /// Sum of all elements in a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to sum. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A scalar tensor with the sum of all elements in `tensor`. | ||
| fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Sum of all elements in a tensor along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to sum. | ||
| /// * `dim` - The dimension along which to sum. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the sum of all elements in `tensor` along `dim`. | ||
| fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>; | ||
| /// Product of all elements in a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to product. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A scalar tensor with the product of all elements in `tensor`. | ||
| fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| // Default implementation via the identity prod(x) = exp(sum(log(x))). | ||
| // NOTE(review): log of a negative element is NaN, so this fallback is only | ||
| // correct for non-negative inputs; backends should override for general use. | ||
| B::float_exp(B::float_sum(B::float_log(tensor))) | ||
| } | ||
| /// Product of all elements in a tensor along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to product. | ||
| /// * `dim` - The dimension along which to take the product. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the product of all elements in `tensor` along `dim`. | ||
| fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> { | ||
| // Default implementation via the identity prod(x) = exp(sum(log(x))) along `dim`. | ||
| // NOTE(review): log of a negative element is NaN, so this fallback is only | ||
| // correct for non-negative inputs; backends should override for general use. | ||
| B::float_exp(B::float_sum_dim(B::float_log(tensor), dim)) | ||
| } | ||
| /// Mean of all elements in a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to mean. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A scalar tensor with the mean of all elements in `tensor`. | ||
| fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| // Default implementation: mean = sum / element count, with the count | ||
| // converted from i64 into the backend's float element type. | ||
| let num_elems = tensor.shape().num_elements(); | ||
| B::float_div_scalar(B::float_sum(tensor), (num_elems as i64).elem()) | ||
| } | ||
| /// Mean of all elements in a tensor along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to mean. | ||
| /// * `dim` - The dimension along which to mean. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the mean of all elements in `tensor` along `dim`. | ||
| fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>; | ||
| /// Computes the cumulative sum of elements along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to compute the cumulative sum of. | ||
| /// * `dim` - The dimension along which to compute the cumulative sum. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape where each element is the cumulative sum | ||
| /// of all elements up to and including that position along the dimension. | ||
| fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>; | ||
| /// Computes the cumulative product of elements along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to compute the cumulative product of. | ||
| /// * `dim` - The dimension along which to compute the cumulative product. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape where each element is the cumulative product | ||
| /// of all elements up to and including that position along the dimension. | ||
| fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>; | ||
| /// Computes the cumulative minimum of elements along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to compute the cumulative minimum of. | ||
| /// * `dim` - The dimension along which to compute the cumulative minimum. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape where each element is the minimum | ||
| /// of all elements up to and including that position along the dimension. | ||
| fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>; | ||
| /// Computes the cumulative maximum of elements along a dimension. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to compute the cumulative maximum of. | ||
| /// * `dim` - The dimension along which to compute the cumulative maximum. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape where each element is the maximum | ||
| /// of all elements up to and including that position along the dimension. | ||
| fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>; | ||
| /// Converts a tensor to another floating point data type. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to convert. | ||
| /// * `dtype` - The target data type. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same values as `tensor` but in the target floating point data type. | ||
| fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>; | ||
| /// Returns a new tensor with exponential values. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to exponentiate. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with exponential values. | ||
| fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Returns a new tensor with natural logarithm values. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to take the logarithm of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with natural logarithm values. | ||
| fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Returns a new tensor with logarithm values of (1 + Xi). | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to take the logarithm of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi). | ||
| fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Element-wise power with a FloatTensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The elements of `lhs` raised to the power of the elements of `rhs`. | ||
| fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Element-wise power with an IntTensor exponent. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side (base) tensor. | ||
| /// * `rhs` - The right-hand side `IntTensor` of exponents. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The elements of `lhs` raised element-wise to the power of the elements of | ||
| /// `rhs`. The result is a `FloatTensor`. | ||
| fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> { | ||
| // Default implementation: promote the integer exponents to float and defer to powf. | ||
| Self::float_powf(lhs, B::int_into_float(rhs)) | ||
| } | ||
| /// Raises a tensor to the power of an int scalar. | ||
| /// | ||
| /// # Backend Implementors Note | ||
| /// | ||
| /// A number of common exponent cases can be implemented with operations | ||
| /// which are much cheaper than generic exponentiation. | ||
| /// | ||
| /// This (`Backend` impl overridable) operation handles generic optimizations | ||
| /// for several common integer exponent cases; and then dispatches to | ||
| /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`] | ||
| /// operation to handle the generic case. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The elements of `lhs` raised to the value of `rhs`. | ||
| fn float_powi_scalar(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> { | ||
| let exp = rhs.elem::<i32>(); | ||
| match exp { | ||
| 0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()), | ||
| 1 => lhs, | ||
| 2 => B::float_mul(lhs.clone(), lhs), | ||
| -1 => Self::float_recip(lhs), | ||
| -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)), | ||
| _ => Self::float_powi_scalar_impl(lhs, rhs), | ||
| } | ||
| } | ||
| /// Raises a tensor to the power of an int scalar. | ||
| /// | ||
| /// # Backend Implementors Note | ||
| /// | ||
| /// This is the generic implementation of integer exponentiation | ||
| /// called by [`Self::float_powi_scalar`] in the fallback case. | ||
| /// | ||
| /// As a general rule, this should not be called directly. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `lhs` - The left-hand side tensor. | ||
| /// * `rhs` - The right-hand side scalar. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The elements of `lhs` raised to the value of `rhs`. | ||
| fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> { | ||
| // Avoid a recursive loop by deferring directly to float_powf_scalar_impl. | ||
| // NOTE(review): the exponent is converted through f32 here, so integer | ||
| // exponents with magnitude above 2^24 may lose precision — confirm acceptable. | ||
| Self::float_powf_scalar_impl(lhs, rhs.elem::<f32>()) | ||
| } | ||
| /// Returns a new tensor with values raised to the power of float `value`. | ||
| /// | ||
| /// # Backend Implementors Note | ||
| /// | ||
| /// This (`Backend` impl overridable) operation dispatches integer exponentiation | ||
| /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to | ||
| /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`] | ||
| /// operation to handle the generic case. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to exponentiate. | ||
| /// * `value` - The exponent. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with values raised to the power of `value`. | ||
| fn float_powf_scalar(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B> { | ||
| // Only take the integer fast path when the exponent is integral AND exactly | ||
| // representable as an i32. Rust's float-to-int `as` cast saturates, so without | ||
| // the range check an integral exponent such as 1e10 would be silently clamped | ||
| // to i32::MAX and dispatched with the wrong exponent. Any integral f32 in | ||
| // [-2^31, 2^31) converts exactly; NaN fails the floor check and falls through. | ||
| if num_traits::Float::floor(value) == value | ||
| && (-2_147_483_648.0..2_147_483_648.0).contains(&value) | ||
| { | ||
| let exp = B::IntElem::from_elem(value as i32); | ||
| Self::float_powi_scalar(tensor, exp) | ||
| } else { | ||
| Self::float_powf_scalar_impl(tensor, value) | ||
| } | ||
| } | ||
| /// Returns a new tensor with values raised to the power of float `value`. | ||
| /// | ||
| /// # Backend Implementors Note | ||
| /// | ||
| /// This is the generic implementation of integer exponentiation | ||
| /// called by [`Self::float_powf_scalar`] in the fallback case. | ||
| /// | ||
| /// This is the minimal required support a `Backend` must implement | ||
| /// for exponentiation. | ||
| /// | ||
| /// As a general rule, this should not be called directly. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to exponentiate. | ||
| /// * `value` - The exponent. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with values raised to the power of `value`. | ||
| fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B>; | ||
| /// Returns a new tensor with square root values. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to take the square root of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with square root values. | ||
| fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Returns a new tensor with absolute values. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to take absolute value of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with absolute values. | ||
| fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Returns a new tensor with cosine values. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to take the cosine of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with cosine values. | ||
| fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Returns a new tensor with sine values. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to take the sine of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with sine values. | ||
| fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>; | ||
| /// Returns a new tensor with tangent values. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to take the tangent of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with tangent values. | ||
| fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| // Default implementation via the identity tan(x) = sin(x) / cos(x). | ||
| let sin = B::float_sin(tensor.clone()); | ||
| let cos = B::float_cos(tensor); | ||
| B::float_div(sin, cos) | ||
| } | ||
| /// Returns a new tensor with hyperbolic cosine values. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to take the hyperbolic cosine of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with hyperbolic cosine values. | ||
| fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| // cosh = ( e^x + e^(-x) ) / 2 | ||
| let e_x = B::float_exp(tensor.clone()); | ||
| let e_neg_x = B::float_exp(B::float_neg(tensor)); | ||
| let num = B::float_add(e_x, e_neg_x); // e^x + e^(-x) | ||
| B::float_div_scalar(num, 2.0.elem()) | ||
| } | ||
| /// Returns a new tensor with hyperbolic sine values. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to take the hyperbolic sine of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with hyperbolic sine values. | ||
| fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| // sinh = ( e^x - e^(-x) ) / 2 | ||
| let e_x = B::float_exp(tensor.clone()); | ||
| let e_neg_x = B::float_exp(B::float_neg(tensor)); | ||
| let num = B::float_sub(e_x, e_neg_x); // e^x - e^(-x) | ||
| B::float_div_scalar(num, 2.0.elem()) | ||
| } | ||
| /// Returns a new tensor with hyperbolic tangent values. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to take the hyperbolic tangent of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` with hyperbolic tangent values. | ||
| fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| let sinh = B::float_sinh(tensor.clone()); | ||
| let cosh = B::float_cosh(tensor); | ||
| B::float_div(sinh, cosh) | ||
| } | ||
/// Returns a new tensor with rounded values.
///
/// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
/// strategy, with halfway cases rounded to the nearest even integer value.
///
/// # Arguments
///
/// * `tensor` - The tensor to be rounded.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with rounded values.
fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
/// Returns a new tensor with floored values (each element rounded toward negative infinity).
///
/// # Arguments
///
/// * `tensor` - The tensor to be floored.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with floored values.
fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
/// Returns a new tensor with ceiled values (each element rounded toward positive infinity).
///
/// # Arguments
///
/// * `tensor` - The tensor to be ceiled.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with ceiled values.
fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
/// Returns a new tensor with truncated values (each element rounded toward zero).
///
/// # Arguments
///
/// * `tensor` - The tensor to be truncated.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with truncated values.
fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;
/// Returns a new tensor with the error function values.
///
/// Computes the Gauss error function `erf(x)` element-wise.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the error function of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with error function values.
fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
/// Concatenates tensors along a dimension.
///
/// # Arguments
///
/// * `tensors` - The tensors to concatenate.
/// * `dim` - The dimension along which to concatenate.
///
/// # Returns
///
/// A tensor with the concatenated tensors along `dim`.
fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
    // Default implementation: copies each input into an output tensor via
    // successive slice_assign operations.
    // NOTE(review): inputs are presumably required to match on every
    // dimension except `dim` — confirm against `cat_with_slice_assign`.
    cat_with_slice_assign::<B, Float>(
        tensors.into_iter().map(TensorPrimitive::Float).collect(),
        dim,
    )
    .tensor()
}
/// Gets the indices of the maximum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements of.
/// * `dim` - The dimension along which to get the maximum elements.
///
/// # Returns
///
/// A tensor with the indices of the maximum elements of `tensor` along `dim`.
///
/// NOTE(review): tie-breaking between equal maxima is not specified here and
/// may differ between backends — confirm before relying on a specific index.
fn float_argmax(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
/// Gets the indices of the minimum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements of.
/// * `dim` - The dimension along which to get the minimum elements.
///
/// # Returns
///
/// A tensor with the indices of the minimum elements of `tensor` along `dim`.
///
/// NOTE(review): tie-breaking between equal minima is not specified here and
/// may differ between backends — confirm before relying on a specific index.
fn float_argmin(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
| /// Gets the maximum element of a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum elements of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the maximum element of `tensor`. | ||
| fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| let shape = tensor.shape(); | ||
| let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()])); | ||
| B::float_max_dim(tensor, 0) | ||
| } | ||
| /// Gets the maximum elements of a tensor along an axis. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum elements of. | ||
| /// * `dim` - The dimension along which to get the maximum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the maximum elements of `tensor` along `dim`. | ||
| fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> { | ||
| let index = B::float_argmax(tensor.clone(), dim); | ||
| B::float_gather(dim, tensor, index) | ||
| } | ||
| /// Gets the maximum elements of a tensor along an axis and their indices. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum elements of. | ||
| /// * `dim` - The dimension along which to get the maximum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tuple with the maximum elements of `tensor` along `dim` and their indices. | ||
| fn float_max_dim_with_indices( | ||
| tensor: FloatTensor<B>, | ||
| dim: usize, | ||
| ) -> (FloatTensor<B>, IntTensor<B>) { | ||
| let index = B::float_argmax(tensor.clone(), dim); | ||
| let values = B::float_gather(dim, tensor, index.clone()); | ||
| (values, index) | ||
| } | ||
| /// Gets the minimum element of a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the minimum elements of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the minimum element of `tensor`. | ||
| fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| let shape = tensor.shape(); | ||
| let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()])); | ||
| B::float_min_dim(tensor, 0) | ||
| } | ||
| /// Gets the minimum elements of a tensor along an axis. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the minimum elements of. | ||
| /// * `dim` - The dimension along which to get the minimum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the minimum elements of `tensor` along `dim`. | ||
| fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> { | ||
| let index = B::float_argmin(tensor.clone(), dim); | ||
| B::float_gather(dim, tensor, index) | ||
| } | ||
| /// Gets the minimum elements of a tensor along an axis and their indices. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the minimum elements of. | ||
| /// * `dim` - The dimension along which to get the minimum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tuple with the minimum elements of `tensor` along `dim` and their indices. | ||
| fn float_min_dim_with_indices( | ||
| tensor: FloatTensor<B>, | ||
| dim: usize, | ||
| ) -> (FloatTensor<B>, IntTensor<B>) { | ||
| let index = B::float_argmin(tensor.clone(), dim); | ||
| let values = B::float_gather(dim, tensor, index.clone()); | ||
| (values, index) | ||
| } | ||
| /// Gets the maximum absolute element of a tensor. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum elements of. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the maximum element of `tensor`. | ||
| fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| let shape = tensor.shape(); | ||
| let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()])); | ||
| B::float_max_abs_dim(tensor, 0) | ||
| } | ||
| /// Gets the maximum absolute elements of a tensor along an axis. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to get the maximum elements of. | ||
| /// * `dim` - The dimension along which to get the maximum elements. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the maximum elements of `tensor` along `dim`. | ||
| fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> { | ||
| B::float_max_dim(B::float_abs(tensor), dim) | ||
| } | ||
| /// Tests if any element in the float `tensor` evaluates to True. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise. | ||
| fn float_any(tensor: FloatTensor<B>) -> BoolTensor<B> { | ||
| let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem()); | ||
| let bool_tensor = B::bool_not(bool_tensor); | ||
| let sum = B::float_sum(B::bool_into_float(bool_tensor)); | ||
| B::float_greater_elem(sum, 0.0f32.elem()) | ||
| } | ||
| /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// * `dim` - The axis along which to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis | ||
| /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the | ||
| /// input evaluates to True, False otherwise. | ||
| fn float_any_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> { | ||
| let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem()); | ||
| let bool_tensor = B::bool_not(bool_tensor); | ||
| let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim); | ||
| B::float_greater_elem(sum, 0.0f32.elem()) | ||
| } | ||
| /// Tests if all elements in the float `tensor` evaluate to True. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor | ||
| /// evaluate to True, False otherwise. | ||
| fn float_all(tensor: FloatTensor<B>) -> BoolTensor<B> { | ||
| let num_elems = tensor.shape().num_elements(); | ||
| let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem()); | ||
| let bool_tensor = B::bool_not(bool_tensor); | ||
| let sum = B::float_sum(B::bool_into_float(bool_tensor)); | ||
| B::float_equal_elem(sum, (num_elems as f32).elem()) | ||
| } | ||
| /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to test. | ||
| /// * `dim` - The axis along which to test. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis | ||
| /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input | ||
| /// evaluates to True, False otherwise. | ||
| fn float_all_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> { | ||
| let num_elems = tensor.shape().dims[dim]; | ||
| let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem()); | ||
| let bool_tensor = B::bool_not(bool_tensor); | ||
| let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim); | ||
| B::float_equal_elem(sum, (num_elems as f32).elem()) | ||
| } | ||
| /// Returns the signs of the float `tensor`. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor to extract the signs from. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`. | ||
| fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> { | ||
| let zeros = B::float_zeros( | ||
| tensor.shape(), | ||
| &B::float_device(&tensor), | ||
| tensor.dtype().into(), | ||
| ); | ||
| let less_than_zero = B::float_lower_elem(tensor.clone(), 0.0f32.elem()); | ||
| let greater_than_zero = B::float_greater_elem(tensor, 0.0f32.elem()); | ||
| let mut result = B::float_mask_fill(zeros, less_than_zero, (-1.0f32).elem()); | ||
| result = B::float_mask_fill(result, greater_than_zero, 1.0f32.elem()); | ||
| result | ||
| } | ||
/// Broadcasts the float `tensor` to the given `shape`.
///
/// # Arguments
///
/// * `tensor` - The tensor to broadcast.
/// * `shape` - The target shape to broadcast to.
///
/// # Returns
///
/// A tensor with the broadcasted shape.
fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
/// Sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the elements are sorted by value.
fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
    sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
}
/// Sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor and corresponding indices, where
/// the elements are sorted by value and the indices map back to the original input tensor.
fn float_sort_with_indices(
    tensor: FloatTensor<B>,
    dim: usize,
    descending: bool,
) -> (FloatTensor<B>, IntTensor<B>) {
    let (values, indices) =
        sort_with_indices::<B, Float>(TensorPrimitive::Float(tensor), dim, descending);
    (values.tensor(), indices)
}
/// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the indices map back to the original input tensor.
fn float_argsort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
    argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending)
}
/// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
/// using the given locations in [-1, 1].
///
/// # Arguments
///
/// * `tensor` - The tensor being sampled from, shape (N, C, H_in, W_in)
/// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
/// * `options` - Grid sampling options (mode, padding_mode, align_corners)
///
/// # Returns
///
/// A tensor with shape (N, C, H_out, W_out)
fn float_grid_sample_2d(
    tensor: FloatTensor<B>,
    grid: FloatTensor<B>,
    options: GridSampleOptions,
) -> FloatTensor<B> {
    // Default implementation delegates to the reference (non-specialized)
    // implementation; backends may override with an accelerated kernel.
    float_grid_sample_2d_ref::<B>(tensor, grid, options)
}
/// Unfold windows along a dimension.
///
/// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
/// where windows are advanced by `step` at each index.
///
/// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
///
/// NOTE(review): the conventional count of complete windows is
/// `(shape[dim] - size) / step + 1` when `size <= shape[dim]` (e.g. dim 5,
/// size 3, step 1 yields 3 windows, but the formula above gives 2) — confirm
/// the stated formula against the implementations.
///
/// # Arguments
///
/// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
/// * `dim` - the selected dim.
/// * `size` - the size of each unfolded window.
/// * `step` - the step between each window.
///
/// # Returns
///
/// A tensor view with shape ``[pre=..., windows, size, post=...]``.
fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)
-> FloatTensor<B>;
| /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value. | ||
| fn float_is_nan(tensor: FloatTensor<B>) -> BoolTensor<B> { | ||
| // Check if the input tensor is NaN by comparing it to itself | ||
| // NaN is the only value that is not equal to itself | ||
| B::float_not_equal(tensor.clone(), tensor) | ||
| } | ||
| /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF). | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor where `true` indicates that the value is infinite | ||
| fn float_is_inf(tensor: FloatTensor<B>) -> BoolTensor<B> { | ||
| B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.elem()) | ||
| } | ||
| } |
| use alloc::vec::Vec; | ||
| use core::future::Future; | ||
| use super::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; | ||
| use crate::{ | ||
| TensorData, | ||
| backend::{Backend, ExecutionError}, | ||
| }; | ||
#[derive(Default)]
/// Contains all tensor primitives that are going to be read.
///
/// The order of each vector is preserved in the corresponding field of
/// [TransactionPrimitiveData] returned by [TransactionOps::tr_execute].
pub struct TransactionPrimitive<B: Backend> {
    /// Float tensors.
    pub read_floats: Vec<FloatTensor<B>>,
    /// Quantized tensors.
    pub read_qfloats: Vec<QuantizedTensor<B>>,
    /// Int tensors.
    pub read_ints: Vec<IntTensor<B>>,
    /// Bool tensors.
    pub read_bools: Vec<BoolTensor<B>>,
}
#[derive(Default)]
/// Contains all [data](TensorData) related to a [transaction](TransactionPrimitive).
///
/// Each field matches the same-named field of the originating transaction,
/// element for element.
pub struct TransactionPrimitiveData {
    /// Float tensor data.
    pub read_floats: Vec<TensorData>,
    /// Quantized tensor data.
    pub read_qfloats: Vec<TensorData>,
    /// Int tensor data.
    pub read_ints: Vec<TensorData>,
    /// Bool tensor data.
    pub read_bools: Vec<TensorData>,
}
/// Operations that are sync by nature and that can be batched together in transactions to improve
/// compute utilization with efficient laziness.
pub trait TransactionOps<B: Backend> {
    /// Executes a [transaction](TransactionPrimitive) and returns its
    /// [data](TransactionPrimitiveData).
    ///
    /// The default implementation reads each tensor sequentially, awaiting one
    /// transfer at a time and short-circuiting on the first error; backends can
    /// override this to batch the reads.
    fn tr_execute(
        transaction: TransactionPrimitive<B>,
    ) -> impl Future<Output = Result<TransactionPrimitiveData, ExecutionError>> + Send {
        async move {
            let mut floats = Vec::new();
            let mut qfloats = Vec::new();
            let mut ints = Vec::new();
            let mut bools = Vec::new();
            // Each `?` propagates the first ExecutionError encountered.
            for t in transaction.read_floats {
                floats.push(B::float_into_data(t).await?);
            }
            for t in transaction.read_qfloats {
                qfloats.push(B::q_into_data(t).await?);
            }
            for t in transaction.read_ints {
                ints.push(B::int_into_data(t).await?);
            }
            for t in transaction.read_bools {
                bools.push(B::bool_into_data(t).await?);
            }
            Ok(TransactionPrimitiveData {
                read_floats: floats,
                read_qfloats: qfloats,
                read_ints: ints,
                read_bools: bools,
            })
        }
    }
}
| use crate::{Tensor, backend::Backend}; | ||
/// The observed input calibration range.
#[derive(Clone, Debug)]
pub struct CalibrationRange<B: Backend> {
    /// Minimum observed value(s), as a rank-1 tensor.
    pub min: Tensor<B, 1>,
    /// Maximum observed value(s), as a rank-1 tensor.
    pub max: Tensor<B, 1>,
}
/// Calibration method used to compute the quantization range mapping.
pub enum Calibration {
    /// Computes the quantization range mapping based on the observed min and max values.
    MinMax,
}
| mod calibration; | ||
| mod parameters; | ||
| mod primitive; | ||
| mod scheme; | ||
| pub use burn_std::quantization::QuantizedBytes; | ||
| pub use calibration::*; | ||
| pub use parameters::*; | ||
| pub use primitive::*; | ||
| pub use scheme::*; |
| use crate::{DType, Shape, Tensor, backend::Backend}; | ||
| use alloc::vec::Vec; | ||
| pub use burn_std::quantization::QParams; | ||
/// The tensor quantization parameters.
pub type QuantizationParameters<B> = QParams<Tensor<B, 1>>;
/// The quantization parameters primitive.
///
/// # Remarks
///
/// This is a low-level struct used internally by the library to provide the quantization parameters
/// to the backends. It is not designed for direct usage by users, and not recommended to import
/// or use this struct directly.
///
/// Users should prefer the [QuantizationParameters] struct, which is designed for public use.
pub struct QuantizationParametersPrimitive<B: Backend> {
    /// The scaling factor, as a backend float tensor primitive.
    pub scales: B::FloatTensorPrimitive,
}
| impl<B: Backend> From<QuantizationParameters<B>> for QuantizationParametersPrimitive<B> { | ||
| fn from(value: QuantizationParameters<B>) -> Self { | ||
| QuantizationParametersPrimitive { | ||
| scales: value.scales.primitive.tensor(), | ||
| } | ||
| } | ||
| } | ||
/// A quantization parameter tensor descriptor.
///
/// Describes where a parameter tensor lives inside a shared buffer and how
/// to interpret it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QParamTensor {
    /// Start of the tensor in the buffer
    pub offset_start: usize,
    /// Offset of tensor end from the end of the buffer
    pub offset_end: usize,
    /// Shape of the tensor
    pub shape: Shape,
    /// Strides of the tensor
    pub strides: Vec<usize>,
    /// Data type of the tensor
    pub dtype: DType,
}
| use crate::quantization::{QuantAcc, QuantPropagation, QuantScheme}; | ||
| /// Quantized tensor primitive. | ||
| pub trait QTensorPrimitive { | ||
| /// Returns the quantization settings for the given tensor. | ||
| fn scheme(&self) -> &QuantScheme; | ||
| /// The precision used for the accumulation in various kernels. | ||
| fn acc_precision(&self) -> QuantAcc { | ||
| QuantAcc::F32 | ||
| } | ||
| /// How quantization is propagated during computation. | ||
| fn propagation(&self) -> QuantPropagation { | ||
| QuantPropagation::Inhibit | ||
| } | ||
| /// Returns the default tensor quantization scheme. | ||
| fn default_scheme() -> QuantScheme { | ||
| QuantScheme::default() | ||
| } | ||
| } |
| // We re-export those types. | ||
| pub use burn_std::quantization::{ | ||
| BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue, | ||
| }; | ||
| use serde::{Deserialize, Serialize}; | ||
| use crate::{Shape, Tensor, TensorMetadata, TensorPrimitive, backend::Backend}; | ||
| use super::{ | ||
| Calibration, CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive, | ||
| }; | ||
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
/// The precision of accumulating elements.
pub enum QuantAcc {
    /// Full precision (f32). This is the default.
    #[default]
    F32,
    /// Half precision (f16).
    F16,
    /// bfloat16 precision.
    BF16,
}
/// Specify if the output of an operation is quantized using the scheme of the input
/// or returned unquantized.
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantPropagation {
    /// The output is quantized using the scheme of the input.
    Propagate,
    /// The output is not quantized. This is the default.
    #[default]
    Inhibit,
}
| /// Compute the quantization range mapping. | ||
| pub fn compute_range<B: Backend, const D: usize>( | ||
| scheme: &QuantScheme, | ||
| tensor: &Tensor<B, D>, | ||
| calibration: &Calibration, | ||
| ) -> CalibrationRange<B> { | ||
| let (min, max) = match &tensor.primitive { | ||
| TensorPrimitive::Float(tensor) => { | ||
| compute_range_primitive::<B>(scheme, tensor.clone(), calibration) | ||
| } | ||
| TensorPrimitive::QFloat(_) => unreachable!(), | ||
| }; | ||
| CalibrationRange { | ||
| min: Tensor::from_primitive(TensorPrimitive::Float(min)), | ||
| max: Tensor::from_primitive(TensorPrimitive::Float(max)), | ||
| } | ||
| } | ||
| /// Calculate the shape of the quantization parameters for a given tensor and level | ||
| pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape { | ||
| match level { | ||
| QuantLevel::Tensor => Shape::new([1]), | ||
| QuantLevel::Block(block_size) => { | ||
| let mut params_shape = data_shape.clone(); | ||
| let block_size = block_size.to_dim_vec(data_shape.num_dims()); | ||
| for (shape, block_size) in params_shape.dims.iter_mut().zip(block_size) { | ||
| *shape = (*shape).div_ceil(block_size as usize); | ||
| } | ||
| params_shape | ||
| } | ||
| } | ||
| } | ||
/// Computes the (min, max) calibration tensors for the given scheme and
/// calibration method, at per-tensor or per-block granularity.
pub(crate) fn compute_range_primitive<B: Backend>(
    scheme: &QuantScheme,
    tensor: B::FloatTensorPrimitive,
    calibration: &Calibration,
) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) {
    match calibration {
        Calibration::MinMax => match scheme.level {
            // Per-tensor: a single global min/max pair.
            QuantLevel::Tensor => (B::float_min(tensor.clone()), B::float_max(tensor)),
            // Per-block: one min/max pair per block.
            QuantLevel::Block(block_size) => {
                let block_elems = block_size.num_elements();
                let shape = tensor.shape();
                let numel = shape.num_elements();
                // Panics (by design) when the tensor cannot be tiled exactly.
                assert_eq!(
                    numel % block_elems,
                    0,
                    "Tensor {shape:?} must be evenly divisible by block size {block_elems}"
                );
                let num_blocks = numel / block_elems;
                let params_shape = params_shape(&shape, scheme.level);
                // View the data as [num_blocks, block_elems], reduce along axis 1,
                // then restore the per-dimension parameter shape.
                // NOTE(review): this assumes each block occupies a contiguous run
                // in the flattened element order — confirm for block sizes that
                // span multiple dimensions.
                let blocks = B::float_reshape(tensor, Shape::new([num_blocks, block_elems]));
                let blocks_min =
                    B::float_reshape(B::float_min_dim(blocks.clone(), 1), params_shape.clone());
                let blocks_max = B::float_reshape(B::float_max_dim(blocks, 1), params_shape);
                (blocks_min, blocks_max)
            }
        },
    }
}
/// Compute the quantization parameters.
///
/// For symmetric schemes, the scale maps the symmetric input range
/// `[-alpha, alpha]` (with `alpha = max(|min|, |max|)`) onto the quantized
/// integer range `[a, b]`.
pub fn compute_q_params<B: Backend>(
    scheme: &QuantScheme,
    range: CalibrationRange<B>,
) -> QuantizationParameters<B> {
    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            ..
        } => {
            // Quantized range `[a, b]`
            let (a, b) = scheme.value.range();
            // Compute scale to convert an input value in range `[-alpha, alpha]`:
            // the symmetric input width is `2 * alpha`, hence the `mul_scalar(2)`,
            // which is then divided by the quantized width `b - a`.
            let values_range = range.min.abs().max_pair(range.max.abs()).mul_scalar(2);
            QuantizationParameters {
                scales: values_range.div_scalar(b - a),
            }
        }
    }
}
/// Compute the quantization parameters.
///
/// Convenience wrapper over [compute_q_params] that accepts raw min/max
/// float primitives and returns the primitive parameter representation.
pub(crate) fn compute_q_params_primitive<B: Backend>(
    scheme: &QuantScheme,
    min: B::FloatTensorPrimitive,
    max: B::FloatTensorPrimitive,
) -> QuantizationParametersPrimitive<B> {
    let range = CalibrationRange {
        min: Tensor::from_primitive(TensorPrimitive::Float(min)),
        max: Tensor::from_primitive(TensorPrimitive::Float(max)),
    };
    compute_q_params(scheme, range).into()
}
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display