burn-tensor
Advanced tools
| #[burn_tensor_testgen::testgen(transaction)] | ||
| mod tests { | ||
| use super::*; | ||
| use burn_tensor::{Transaction, ops::FloatElem}; | ||
| type FT = FloatElem<TestBackend>; | ||
| // https://github.com/tracel-ai/burn/issues/4021 | ||
| #[test] | ||
| fn should_support_transaction() { | ||
| let rows = 261120; | ||
| let cols = 408; | ||
| let device = Default::default(); | ||
| let j = TestTensor::<2>::zeros([rows, cols], &device); | ||
| let jt = j.clone().transpose(); | ||
| let g = jt.matmul(j); | ||
| let g = g.transpose(); | ||
| let expected = g.to_data(); | ||
| assert_eq!(g.shape().dims(), [cols, cols]); | ||
| // Fails | ||
| let [data] = Transaction::default() | ||
| .register(g) | ||
| .execute() | ||
| .try_into() | ||
| .unwrap(); | ||
| // check byte equality | ||
| assert_eq!(data, expected); | ||
| } | ||
| } |
| { | ||
| "git": { | ||
| "sha1": "88d662d8bfaedc0d83a2151b184d599b37ead065" | ||
| "sha1": "0368cc660dc9fc084292795ffe6b4e060d5aa668" | ||
| }, | ||
| "path_in_vcs": "crates/burn-tensor" | ||
| } |
+18
-58
@@ -15,3 +15,3 @@ # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO | ||
| name = "burn-tensor" | ||
| version = "0.20.0-pre.3" | ||
| version = "0.20.0-pre.4" | ||
| authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"] | ||
@@ -54,3 +54,7 @@ build = false | ||
| [features] | ||
| cubecl = ["dep:cubecl"] | ||
| cubecl = [ | ||
| "dep:cubecl", | ||
| "burn-std/cubecl", | ||
| "burn-backend/cubecl", | ||
| ] | ||
| cubecl-cpu = [ | ||
@@ -72,6 +76,3 @@ "cubecl", | ||
| ] | ||
| default = [ | ||
| "std", | ||
| "burn-common/rayon", | ||
| ] | ||
| default = ["std"] | ||
| doc = ["default"] | ||
@@ -84,6 +85,5 @@ experimental-named-tensor = [] | ||
| std = [ | ||
| "rand/std", | ||
| "half/std", | ||
| "num-traits/std", | ||
| "burn-common/std", | ||
| "burn-std/std", | ||
| "burn-backend/std", | ||
| "colored", | ||
@@ -96,14 +96,14 @@ ] | ||
| [dependencies.burn-common] | ||
| version = "=0.20.0-pre.3" | ||
| [dependencies.burn-backend] | ||
| version = "0.20.0-pre.4" | ||
| default-features = false | ||
| [dependencies.burn-std] | ||
| version = "0.20.0-pre.4" | ||
| default-features = false | ||
| [dependencies.burn-tensor-testgen] | ||
| version = "=0.20.0-pre.3" | ||
| version = "0.20.0-pre.4" | ||
| optional = true | ||
| [dependencies.bytemuck] | ||
| version = "1.24.0" | ||
| features = ["extern_crate_alloc"] | ||
| [dependencies.colored] | ||
@@ -114,10 +114,6 @@ version = "3.0.0" | ||
| [dependencies.cubecl] | ||
| version = "=0.9.0-pre.3" | ||
| version = "=0.9.0-pre.4" | ||
| optional = true | ||
| default-features = false | ||
| [dependencies.cubecl-quant] | ||
| version = "=0.9.0-pre.3" | ||
| default-features = false | ||
| [dependencies.derive-new] | ||
@@ -127,14 +123,4 @@ version = "0.7.0" | ||
| [dependencies.half] | ||
| version = "2.7.1" | ||
| features = [ | ||
| "alloc", | ||
| "num-traits", | ||
| "serde", | ||
| "bytemuck", | ||
| ] | ||
| default-features = false | ||
| [dependencies.hashbrown] | ||
| version = "0.15.5" | ||
| version = "0.16" | ||
@@ -146,10 +132,2 @@ [dependencies.num-traits] | ||
| [dependencies.rand] | ||
| version = "0.9.2" | ||
| default-features = false | ||
| [dependencies.rand_distr] | ||
| version = "0.5.1" | ||
| default-features = false | ||
| [dependencies.serde] | ||
@@ -163,23 +141,5 @@ version = "1.0.228" | ||
| [dependencies.serde_bytes] | ||
| version = "0.11.18" | ||
| features = ["alloc"] | ||
| default-features = false | ||
| [dev-dependencies.bincode] | ||
| version = "2.0.1" | ||
| features = [ | ||
| "alloc", | ||
| "serde", | ||
| ] | ||
| default-features = false | ||
| [dev-dependencies.paste] | ||
| version = "1" | ||
| [dev-dependencies.rand] | ||
| version = "0.9.2" | ||
| features = ["thread_rng"] | ||
| default-features = false | ||
| [lints.clippy] | ||
@@ -186,0 +146,0 @@ |
+4
-66
@@ -24,73 +24,11 @@ #![cfg_attr(not(feature = "std"), no_std)] | ||
| pub use half::{bf16, f16}; | ||
| pub(crate) use tensor::check::macros::check; | ||
| pub use tensor::*; | ||
| pub use burn_common::stream_id::StreamId; | ||
| pub use burn_std::stream_id::StreamId; | ||
| pub use burn_common::reader::*; // Useful so that backends don't have to add `burn_common` as a dependency. | ||
| // Re-exported types | ||
| pub use burn_std::reader::*; // Useful so that backends don't have to add `burn_std` as a dependency. | ||
| pub use burn_std::{Bytes, bf16, f16}; | ||
| #[cfg(feature = "cubecl")] | ||
| pub use cubecl::flex32; | ||
| #[cfg(feature = "cubecl")] | ||
| mod cube { | ||
| use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind}; | ||
| use cubecl_quant::scheme::QuantScheme; | ||
| use crate::quantization::{QuantStore, QuantValue}; | ||
| impl From<crate::DType> for cubecl::ir::ElemType { | ||
| fn from(dtype: crate::DType) -> Self { | ||
| match dtype { | ||
| crate::DType::F64 => ElemType::Float(FloatKind::F64), | ||
| crate::DType::F32 => ElemType::Float(FloatKind::F32), | ||
| crate::DType::Flex32 => ElemType::Float(FloatKind::Flex32), | ||
| crate::DType::F16 => ElemType::Float(FloatKind::F16), | ||
| crate::DType::BF16 => ElemType::Float(FloatKind::BF16), | ||
| crate::DType::I64 => ElemType::Int(IntKind::I64), | ||
| crate::DType::I32 => ElemType::Int(IntKind::I32), | ||
| crate::DType::I16 => ElemType::Int(IntKind::I16), | ||
| crate::DType::I8 => ElemType::Int(IntKind::I8), | ||
| crate::DType::U64 => ElemType::UInt(UIntKind::U64), | ||
| crate::DType::U32 => ElemType::UInt(UIntKind::U32), | ||
| crate::DType::U16 => ElemType::UInt(UIntKind::U16), | ||
| crate::DType::U8 => ElemType::UInt(UIntKind::U8), | ||
| crate::DType::Bool => ElemType::Bool, | ||
| crate::DType::QFloat(scheme) => match scheme.store { | ||
| QuantStore::Native => match scheme.value { | ||
| QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8), | ||
| QuantValue::E4M3 => Self::Float(FloatKind::E4M3), | ||
| QuantValue::E5M2 => Self::Float(FloatKind::E5M2), | ||
| QuantValue::Q4F | ||
| | QuantValue::Q4S | ||
| | QuantValue::Q2F | ||
| | QuantValue::Q2S | ||
| | QuantValue::E2M1 => { | ||
| panic!("Can't store native sub-byte values") | ||
| } | ||
| }, | ||
| QuantStore::U32 => Self::UInt(UIntKind::U32), | ||
| }, | ||
| } | ||
| } | ||
| } | ||
| impl From<crate::DType> for cubecl::ir::StorageType { | ||
| fn from(dtype: crate::DType) -> cubecl::ir::StorageType { | ||
| match dtype { | ||
| crate::DType::QFloat(QuantScheme { | ||
| store: QuantStore::Native, | ||
| value: QuantValue::E2M1, | ||
| .. | ||
| }) => StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2), | ||
| _ => { | ||
| let elem: ElemType = dtype.into(); | ||
| elem.into() | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| #[cfg(feature = "cubecl-wgpu")] | ||
@@ -97,0 +35,0 @@ mod cube_wgpu { |
@@ -821,19 +821,6 @@ use crate::ops::FloatElem; | ||
| let range = slice.to_range(d_tensor); | ||
| // Empty slices (start >= end) are allowed and produce a tensor with size 0 | ||
| // in that dimension. This matches PyTorch behavior and is required for ONNX | ||
| // compatibility where dynamic slice ranges may become empty at runtime. | ||
| if range.start >= range.end { | ||
| check = check.register( | ||
| "Slice", | ||
| TensorError::new( | ||
| "The provided slice has a range where the start index is bigger or \ | ||
| equal to its end.", | ||
| ) | ||
| .details(format!( | ||
| "The range at dimension '{}' starts at '{}' and is greater or equal to \ | ||
| its end '{}'. Tensor shape {:?}.", | ||
| i, range.start, range.end, shape.dims, | ||
| )), | ||
| ); | ||
| } | ||
| if slice.step() == 0 { | ||
@@ -840,0 +827,0 @@ check = check.register( |
| use crate::AsIndex; | ||
| use crate::FloatDType; | ||
| use crate::Tensor; | ||
| use crate::canonicalize_dim; | ||
| use crate::cast::ToElement; | ||
| use crate::check; | ||
| use crate::check::TensorCheck; | ||
| use crate::indexing::canonicalize_dim; | ||
| use crate::ops::InterpolateMode; | ||
| use crate::ops::GridSampleOptions; | ||
| use crate::quantization::{QuantScheme, QuantizationParameters}; | ||
@@ -658,11 +658,7 @@ use crate::tensor::backend::Backend; | ||
| /// | ||
| /// Interpolation is bilinear. | ||
| /// Padding is border: out of bounds locations will be clamped to the nearest border | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `tensor` - The tensor being sampled from, shape (N, C, H_in, W_in) | ||
| /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1]. | ||
| /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right | ||
| /// * `method` - How to interpolate between samples | ||
| /// * `options` - Grid sampling options (mode, padding_mode, align_corners) | ||
| /// | ||
@@ -672,7 +668,22 @@ /// # Returns | ||
| /// A tensor with shape (N, C, H_out, W_out) | ||
| pub fn grid_sample_2d(self, grid: Tensor<B, D>, method: InterpolateMode) -> Tensor<B, D> { | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```ignore | ||
| /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode}; | ||
| /// | ||
| /// // Default options (bilinear, zeros padding, align_corners=false) | ||
| /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default()); | ||
| /// | ||
| /// // Custom options | ||
| /// let options = GridSampleOptions::new(InterpolateMode::Bilinear) | ||
| /// .with_padding_mode(GridSamplePaddingMode::Border) | ||
| /// .with_align_corners(true); | ||
| /// let output = tensor.grid_sample_2d(grid, options); | ||
| /// ``` | ||
| pub fn grid_sample_2d(self, grid: Tensor<B, D>, options: GridSampleOptions) -> Tensor<B, D> { | ||
| Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d( | ||
| self.primitive.tensor(), | ||
| grid.primitive.tensor(), | ||
| method, | ||
| options, | ||
| ))) | ||
@@ -679,0 +690,0 @@ } |
@@ -13,3 +13,2 @@ pub(crate) mod check; | ||
| mod numeric; | ||
| mod slice; | ||
| mod sort; | ||
@@ -27,4 +26,3 @@ mod take; | ||
| pub use numeric::*; | ||
| pub use slice::*; | ||
| pub use sort::{argsort, sort, sort_with_indices}; | ||
| pub use transaction::*; |
@@ -9,3 +9,3 @@ use core::cmp::Ordering; | ||
| use alloc::{vec, vec::Vec}; | ||
| use burn_common::reader::try_read_sync; | ||
| use burn_std::reader::try_read_sync; | ||
@@ -41,3 +41,6 @@ /// Sort the elements of the input `tensor` by value along a given dimension. | ||
| let device = K::device(&tensor); | ||
| let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); | ||
| let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."; | ||
| let data = try_read_sync(K::into_data_async(tensor)) | ||
| .expect(msg) | ||
| .expect(msg); | ||
| sort_data::<B, K>(data, dim, &device, descending) | ||
@@ -97,3 +100,6 @@ } | ||
| let device = K::device(&tensor); | ||
| let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); | ||
| let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."; | ||
| let data = try_read_sync(K::into_data_async(tensor)) | ||
| .expect(msg) | ||
| .expect(msg); | ||
| sort_data_with_indices::<B, K>(data, dim, &device, descending) | ||
@@ -194,3 +200,6 @@ } | ||
| let device = K::device(&tensor); | ||
| let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."); | ||
| let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation."; | ||
| let data = try_read_sync(K::into_data_async(tensor)) | ||
| .expect(msg) | ||
| .expect(msg); | ||
@@ -197,0 +206,0 @@ argsort_data::<B, K>(data, dim, &device, descending) |
| use alloc::vec::Vec; | ||
| use crate::{ | ||
| BasicOps, Int, Tensor, | ||
| backend::Backend, | ||
| check, | ||
| check::TensorCheck, | ||
| indexing::{AsIndex, canonicalize_dim}, | ||
| AsIndex, BasicOps, Int, Tensor, backend::Backend, canonicalize_dim, check, check::TensorCheck, | ||
| }; | ||
@@ -10,0 +6,0 @@ |
| use super::{BasicOps, Tensor, TensorPrimitive}; | ||
| use crate::{ | ||
| TensorData, | ||
| backend::Backend, | ||
| backend::{Backend, ExecutionError}, | ||
| ops::{BoolTensor, IntTensor, TransactionPrimitive}, | ||
@@ -46,9 +46,20 @@ }; | ||
| pub fn execute(self) -> Vec<TensorData> { | ||
| burn_common::future::block_on(self.execute_async()) | ||
| burn_std::future::block_on(self.execute_async()) | ||
| .expect("Error while reading data: use `try_execute` to handle error at runtime") | ||
| } | ||
| /// Executes the transaction synchronously and returns the [data](TensorData) in the same | ||
| /// order in which they were [registered](Self::register). | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// Any error that might have occurred since the last time the device was synchronized. | ||
| pub fn try_execute(self) -> Result<Vec<TensorData>, ExecutionError> { | ||
| burn_std::future::block_on(self.execute_async()) | ||
| } | ||
| /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order | ||
| /// in which they were [registered](Self::register). | ||
| pub async fn execute_async(self) -> Vec<TensorData> { | ||
| let result = B::tr_execute(self.op).await; | ||
| pub async fn execute_async(self) -> Result<Vec<TensorData>, ExecutionError> { | ||
| let result = B::tr_execute(self.op).await?; | ||
@@ -60,3 +71,4 @@ let mut floats: Vec<_> = result.read_floats.into_iter().map(Some).collect(); | ||
| self.orders | ||
| Ok(self | ||
| .orders | ||
| .into_iter() | ||
@@ -69,3 +81,3 @@ .map(|order| match order { | ||
| }) | ||
| .collect::<Vec<_>>() | ||
| .collect::<Vec<_>>()) | ||
| } | ||
@@ -72,0 +84,0 @@ |
| use alloc::string::String; | ||
| use serde::{Deserialize, Serialize}; | ||
@@ -123,3 +124,5 @@ use crate::tensor::Element; | ||
| /// Sync the backend, ensure that all computation are finished. | ||
| fn sync(_device: &Self::Device) {} | ||
| fn sync(_device: &Self::Device) -> Result<(), SyncError> { | ||
| Ok(()) | ||
| } | ||
@@ -138,2 +141,40 @@ /// Marks the given data as being used as a staging buffer for transfer between CPU and | ||
| /// An error that can happened when syncing a backend. | ||
| #[derive(Serialize, Deserialize, Debug)] | ||
| pub enum SyncError { | ||
| /// A generic error happened while syncing. | ||
| Generic { | ||
| /// The details | ||
| context: String, | ||
| }, | ||
| /// Syncing the device isn't supported. | ||
| NotSupported { | ||
| /// The details | ||
| context: String, | ||
| }, | ||
| } | ||
| /// An error that can happen when syncing a device. | ||
| #[derive(Serialize, Deserialize, Debug)] | ||
| pub enum ExecutionError { | ||
| /// A generic error happened while syncing. | ||
| Generic { | ||
| /// The details | ||
| context: String, | ||
| }, | ||
| } | ||
| impl core::fmt::Display for SyncError { | ||
| fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { | ||
| match self { | ||
| SyncError::Generic { context } => { | ||
| f.write_fmt(format_args!("An error happened while syncing: {}", context)) | ||
| } | ||
| SyncError::NotSupported { context } => { | ||
| f.write_fmt(format_args!("Can't sync the device: {}", context)) | ||
| } | ||
| } | ||
| } | ||
| } | ||
| /// Trait that allows a backend to support autodiff. | ||
@@ -140,0 +181,0 @@ pub trait AutodiffBackend: Backend { |
@@ -1,6 +0,6 @@ | ||
| pub use burn_common::device::*; | ||
| pub use burn_std::device::*; | ||
| /// The handle device trait allows to get an id for a backend device. | ||
| pub trait DeviceOps: | ||
| Clone + Default + PartialEq + Send + Sync + core::fmt::Debug + burn_common::device::Device | ||
| Clone + Default + PartialEq + Send + Sync + core::fmt::Debug + burn_std::device::Device | ||
| { | ||
@@ -7,0 +7,0 @@ /// Returns the [device id](DeviceId). |
| use crate::backend::Backend; | ||
| use crate::indexing::canonicalize_dim; | ||
| use crate::canonicalize_dim; | ||
| use crate::tensor::{BasicOps, Tensor}; | ||
@@ -4,0 +4,0 @@ use crate::{AsIndex, Numeric}; |
+7
-14
| pub(crate) mod stats; | ||
| mod api; | ||
| mod data; | ||
| mod distribution; | ||
| mod element; | ||
| mod shape; | ||
| pub use api::*; | ||
| pub use burn_common::bytes::*; | ||
| pub use data::*; | ||
| pub use distribution::*; | ||
| pub use element::*; | ||
| pub use shape::*; | ||
| // Re-exported types | ||
| pub use burn_backend::{DataError, TensorData, Tolerance, distribution::*, element::*}; | ||
| pub use burn_std::{ | ||
| DType, FloatDType, IntDType, s, | ||
| tensor::{indexing::*, shape::*, slice::*}, | ||
| }; | ||
| /// The activation module. | ||
@@ -28,7 +26,2 @@ pub mod activation; | ||
| /// The indexing module. | ||
| pub mod indexing; | ||
| pub use indexing::AsIndex; | ||
| /// The linalg module. | ||
@@ -35,0 +28,0 @@ pub mod linalg; |
@@ -6,3 +6,4 @@ use super::{ | ||
| use crate::{ | ||
| Bool, ElementConversion, TensorData, TensorMetadata, argwhere_data, backend::Backend, | ||
| Bool, ElementConversion, TensorData, TensorMetadata, argwhere_data, | ||
| backend::{Backend, ExecutionError}, | ||
| tensor::Shape, | ||
@@ -61,3 +62,5 @@ }; | ||
| /// The data structure with the tensor's data. | ||
| fn bool_into_data(tensor: BoolTensor<B>) -> impl Future<Output = TensorData> + Send; | ||
| fn bool_into_data( | ||
| tensor: BoolTensor<B>, | ||
| ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send; | ||
@@ -134,2 +137,7 @@ /// Creates a tensor from the data structure. | ||
| /// The tensor with the values for the given slices. | ||
| /// | ||
| /// # Note | ||
| /// | ||
| /// Empty slices (where start >= end) are handled at the high-level tensor API and will not | ||
| /// be passed to this method. Backend implementations do not need to handle empty slices. | ||
| fn bool_slice(tensor: BoolTensor<B>, slices: &[crate::Slice]) -> BoolTensor<B>; | ||
@@ -432,3 +440,5 @@ | ||
| let device = B::bool_device(&tensor); | ||
| let data = B::bool_into_data(tensor).await; | ||
| let data = B::bool_into_data(tensor) | ||
| .await | ||
| .expect("Can read the data without error"); | ||
| argwhere_data::<B>(data, &device) | ||
@@ -435,0 +445,0 @@ } |
| use super::cat::cat_with_slice_assign; | ||
| use super::repeat_dim::repeat_with_slice_assign; | ||
| use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; | ||
| use crate::backend::ExecutionError; | ||
| use crate::{ | ||
@@ -36,3 +37,5 @@ Distribution, ElementConversion, Int, IntDType, TensorData, backend::Backend, tensor::Shape, | ||
| /// The data structure with the tensor's data. | ||
| fn int_into_data(tensor: IntTensor<B>) -> impl Future<Output = TensorData> + Send; | ||
| fn int_into_data( | ||
| tensor: IntTensor<B>, | ||
| ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send; | ||
@@ -87,2 +90,7 @@ /// Creates a tensor from the data structure. | ||
| /// The elements at the given indices. | ||
| /// | ||
| /// # Note | ||
| /// | ||
| /// Empty slices (where start >= end) are handled at the high-level tensor API and will not | ||
| /// be passed to this method. Backend implementations do not need to handle empty slices. | ||
| fn int_slice(tensor: IntTensor<B>, slices: &[crate::Slice]) -> IntTensor<B>; | ||
@@ -89,0 +97,0 @@ |
@@ -253,2 +253,63 @@ use super::{conv, pool}; | ||
| /// Padding mode for grid sampling when coordinates are out of bounds. | ||
| /// | ||
| /// Matches PyTorch's `padding_mode` parameter in `grid_sample`. | ||
| #[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)] | ||
| pub enum GridSamplePaddingMode { | ||
| /// Fill with zeros for out-of-bounds coordinates. | ||
| #[default] | ||
| Zeros, | ||
| /// Clamp coordinates to the border (use nearest edge value). | ||
| Border, | ||
| /// Reflect coordinates at the boundary. | ||
| Reflection, | ||
| } | ||
| /// Options for grid sampling operations. | ||
| #[derive(Debug, Clone)] | ||
| pub struct GridSampleOptions { | ||
| /// Interpolation mode (bilinear, nearest, or bicubic). | ||
| pub mode: InterpolateMode, | ||
| /// Padding mode for out-of-bounds coordinates. | ||
| pub padding_mode: GridSamplePaddingMode, | ||
| /// If `true`, grid values of -1 and 1 correspond to the corner pixels. | ||
| /// If `false`, they correspond to the corner points of the corner pixels | ||
| /// (i.e., -1 maps to -0.5 and 1 maps to size - 0.5 in pixel coordinates). | ||
| pub align_corners: bool, | ||
| } | ||
| impl Default for GridSampleOptions { | ||
| fn default() -> Self { | ||
| Self { | ||
| mode: InterpolateMode::Bilinear, | ||
| padding_mode: GridSamplePaddingMode::Zeros, | ||
| align_corners: false, | ||
| } | ||
| } | ||
| } | ||
| impl GridSampleOptions { | ||
| /// Create new grid sample options with the given interpolation mode. | ||
| /// | ||
| /// Uses default values for padding_mode (Zeros) and align_corners (false). | ||
| pub fn new(mode: InterpolateMode) -> Self { | ||
| Self { | ||
| mode, | ||
| ..Default::default() | ||
| } | ||
| } | ||
| /// Set the padding mode. | ||
| pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self { | ||
| self.padding_mode = padding_mode; | ||
| self | ||
| } | ||
| /// Set align_corners. | ||
| pub fn with_align_corners(mut self, align_corners: bool) -> Self { | ||
| self.align_corners = align_corners; | ||
| self | ||
| } | ||
| } | ||
| /// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate). | ||
@@ -255,0 +316,0 @@ #[derive(new)] |
@@ -0,5 +1,6 @@ | ||
| use crate::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode}; | ||
| use crate::{ElementConversion, Shape, Slice, TensorMetadata, backend::Backend, ops::FloatTensor}; | ||
| use alloc::vec; | ||
| /// Default implementation of float_grid_sample_2d with bilinear interpolation and border padding | ||
| /// Reference implementation of grid_sample_2d that supports all options. | ||
| /// | ||
@@ -11,2 +12,3 @@ /// # Arguments | ||
| /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right | ||
| /// * `options` - Grid sampling options | ||
| /// | ||
@@ -16,6 +18,28 @@ /// # Returns | ||
| /// A tensor with shape (N, C, H_out, W_out) | ||
| pub fn float_grid_sample_2d_bilinear<B: Backend>( | ||
| pub fn float_grid_sample_2d_ref<B: Backend>( | ||
| tensor: FloatTensor<B>, | ||
| grid: FloatTensor<B>, | ||
| options: GridSampleOptions, | ||
| ) -> FloatTensor<B> { | ||
| match options.mode { | ||
| InterpolateMode::Bilinear => float_grid_sample_2d_bilinear::<B>( | ||
| tensor, | ||
| grid, | ||
| options.padding_mode, | ||
| options.align_corners, | ||
| ), | ||
| _ => todo!( | ||
| "Default implementation for grid_sample_2d with {:?} unimplemented", | ||
| options.mode | ||
| ), | ||
| } | ||
| } | ||
| /// Bilinear grid sampling implementation. | ||
| fn float_grid_sample_2d_bilinear<B: Backend>( | ||
| tensor: FloatTensor<B>, | ||
| grid: FloatTensor<B>, | ||
| padding_mode: GridSamplePaddingMode, | ||
| align_corners: bool, | ||
| ) -> FloatTensor<B> { | ||
| let n = tensor.shape().dims[0]; | ||
@@ -28,9 +52,3 @@ let c = tensor.shape().dims[1]; | ||
| let x_max_half = (w_in - 1) as f64 / 2.0; | ||
| let y_max_half = (h_in - 1) as f64 / 2.0; | ||
| // Clamp grid | ||
| let grid = B::float_clamp(grid, (-1_f32).elem(), (1_f32).elem()); | ||
| // Separate x and y coordinates | ||
| // Separate x and y coordinates from grid | ||
| // shape: (N, H_out, W_out, 1) | ||
@@ -55,38 +73,115 @@ let grid_x_slice = vec![ | ||
| // Scale grid locations from [-1, 1] and [-1, 1] to [0..W_out] and [0..H_out] | ||
| let grid_x = B::float_mul_scalar(grid_x, x_max_half.elem()); | ||
| let grid_x = B::float_add_scalar(grid_x, x_max_half.elem()); | ||
| let grid_y = B::float_mul_scalar(grid_y, x_max_half.elem()); | ||
| let grid_y = B::float_add_scalar(grid_y, y_max_half.elem()); | ||
| // Convert normalized grid coordinates [-1, 1] to pixel coordinates | ||
| let w_in_f = w_in as f64; | ||
| let h_in_f = h_in as f64; | ||
| // Get low and high x locations | ||
| let (grid_x, grid_y) = if align_corners { | ||
| // align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2 | ||
| // Maps -1 to 0 and 1 to width - 1 | ||
| let grid_x = B::float_add_scalar(grid_x, 1.0f32.elem()); | ||
| let grid_x = B::float_mul_scalar(grid_x, ((w_in_f - 1.0) / 2.0).elem()); | ||
| let grid_y = B::float_add_scalar(grid_y, 1.0f32.elem()); | ||
| let grid_y = B::float_mul_scalar(grid_y, ((h_in_f - 1.0) / 2.0).elem()); | ||
| (grid_x, grid_y) | ||
| } else { | ||
| // align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5 | ||
| // Maps -1 to -0.5 and 1 to width - 0.5 | ||
| let grid_x = B::float_add_scalar(grid_x, 1.0f32.elem()); | ||
| let grid_x = B::float_mul_scalar(grid_x, (w_in_f / 2.0).elem()); | ||
| let grid_x = B::float_sub_scalar(grid_x, 0.5f32.elem()); | ||
| let grid_y = B::float_add_scalar(grid_y, 1.0f32.elem()); | ||
| let grid_y = B::float_mul_scalar(grid_y, (h_in_f / 2.0).elem()); | ||
| let grid_y = B::float_sub_scalar(grid_y, 0.5f32.elem()); | ||
| (grid_x, grid_y) | ||
| }; | ||
| // Apply padding mode to coordinates | ||
| let (grid_x, grid_y) = match padding_mode { | ||
| GridSamplePaddingMode::Border => { | ||
| // Clamp coordinates to valid range [0, size-1] | ||
| let grid_x = B::float_clamp(grid_x, 0.0f32.elem(), ((w_in - 1) as f32).elem()); | ||
| let grid_y = B::float_clamp(grid_y, 0.0f32.elem(), ((h_in - 1) as f32).elem()); | ||
| (grid_x, grid_y) | ||
| } | ||
| GridSamplePaddingMode::Reflection => { | ||
| // Reflect coordinates at boundaries | ||
| // For now, use a simplified reflection that works for common cases | ||
| let grid_x = reflect_coordinates::<B>(grid_x, w_in_f); | ||
| let grid_y = reflect_coordinates::<B>(grid_y, h_in_f); | ||
| (grid_x, grid_y) | ||
| } | ||
| GridSamplePaddingMode::Zeros => { | ||
| // Keep coordinates as-is, we'll mask out-of-bounds later | ||
| (grid_x, grid_y) | ||
| } | ||
| }; | ||
| // Get floor indices for the four corners | ||
| let grid_x_floored = B::float_floor(grid_x.clone()); | ||
| let grid_x_plus_one = B::float_floor(B::float_add_scalar(grid_x.clone(), 1.elem())); | ||
| let x_indices_low = B::float_into_int(grid_x_floored.clone()); | ||
| let x_indices_high = B::float_into_int(grid_x_plus_one.clone()); | ||
| // Get low and high x locations | ||
| let grid_y_floored = B::float_floor(grid_y.clone()); | ||
| let grid_y_plus_one = B::float_floor(B::float_add_scalar(grid_y.clone(), 1.elem())); | ||
| let y_indices_low = B::float_into_int(grid_y_floored.clone()); | ||
| let y_indices_high = B::float_into_int(grid_y_plus_one.clone()); | ||
| // Clamp locations: border padding | ||
| let x_indices_low = B::int_clamp(x_indices_low, 0.elem(), ((w_in - 1) as u32).elem()); | ||
| let x_indices_high = B::int_clamp(x_indices_high, 0.elem(), ((w_in - 1) as u32).elem()); | ||
| let y_indices_low = B::int_clamp(y_indices_low, 0.elem(), ((h_in - 1) as u32).elem()); | ||
| let y_indices_high = B::int_clamp(y_indices_high, 0.elem(), ((h_in - 1) as u32).elem()); | ||
| // Compute interpolation weights (fractional part) | ||
| let x_frac = B::float_sub(grid_x.clone(), grid_x_floored.clone()); | ||
| let y_frac = B::float_sub(grid_y.clone(), grid_y_floored.clone()); | ||
| // Needs shape (N, C, H_out, W_out, W_in) for the first gather operationd | ||
| let y_indices_low = B::int_reshape(y_indices_low, Shape::new([n, 1, h_out, w_out, 1])); | ||
| let y_indices_low = B::int_expand(y_indices_low, Shape::new([n, c, h_out, w_out, w_in])); | ||
| let y_indices_high = B::int_reshape(y_indices_high, Shape::new([n, 1, h_out, w_out, 1])); | ||
| let y_indices_high = B::int_expand(y_indices_high, Shape::new([n, c, h_out, w_out, w_in])); | ||
| // Convert to integer indices | ||
| let x0 = B::float_into_int(grid_x_floored.clone()); | ||
| let y0 = B::float_into_int(grid_y_floored.clone()); | ||
| let x1 = B::float_into_int(B::float_add_scalar(grid_x_floored, 1.0f32.elem())); | ||
| let y1 = B::float_into_int(B::float_add_scalar(grid_y_floored, 1.0f32.elem())); | ||
| // Needs shape (N, C, H_out, W_out, 1) for the second gather operation | ||
| let x_indices_low = B::int_reshape(x_indices_low, Shape::new([n, 1, h_out, w_out, 1])); | ||
| let x_indices_low = B::int_expand(x_indices_low, Shape::new([n, c, h_out, w_out, 1])); | ||
| let x_indices_high = B::int_reshape(x_indices_high, Shape::new([n, 1, h_out, w_out, 1])); | ||
| let x_indices_high = B::int_expand(x_indices_high, Shape::new([n, c, h_out, w_out, 1])); | ||
| // Create masks for out-of-bounds coordinates (only used for zeros padding) | ||
| let (mask_00, mask_01, mask_10, mask_11) = if padding_mode == GridSamplePaddingMode::Zeros { | ||
| let x0_valid = B::int_greater_equal_elem(x0.clone(), 0.elem()); | ||
| let x0_valid = B::bool_and( | ||
| x0_valid, | ||
| B::int_lower_elem(x0.clone(), (w_in as i32).elem()), | ||
| ); | ||
| let x1_valid = B::int_greater_equal_elem(x1.clone(), 0.elem()); | ||
| let x1_valid = B::bool_and( | ||
| x1_valid, | ||
| B::int_lower_elem(x1.clone(), (w_in as i32).elem()), | ||
| ); | ||
| let y0_valid = B::int_greater_equal_elem(y0.clone(), 0.elem()); | ||
| let y0_valid = B::bool_and( | ||
| y0_valid, | ||
| B::int_lower_elem(y0.clone(), (h_in as i32).elem()), | ||
| ); | ||
| let y1_valid = B::int_greater_equal_elem(y1.clone(), 0.elem()); | ||
| let y1_valid = B::bool_and( | ||
| y1_valid, | ||
| B::int_lower_elem(y1.clone(), (h_in as i32).elem()), | ||
| ); | ||
| ( | ||
| Some(B::bool_and(x0_valid.clone(), y0_valid.clone())), | ||
| Some(B::bool_and(x0_valid.clone(), y1_valid.clone())), | ||
| Some(B::bool_and(x1_valid.clone(), y0_valid)), | ||
| Some(B::bool_and(x1_valid, y1_valid)), | ||
| ) | ||
| } else { | ||
| (None, None, None, None) | ||
| }; | ||
| // Clamp indices to valid range for gather | ||
| let x0_clamped = B::int_clamp(x0, 0.elem(), ((w_in - 1) as i32).elem()); | ||
| let x1_clamped = B::int_clamp(x1, 0.elem(), ((w_in - 1) as i32).elem()); | ||
| let y0_clamped = B::int_clamp(y0, 0.elem(), ((h_in - 1) as i32).elem()); | ||
| let y1_clamped = B::int_clamp(y1, 0.elem(), ((h_in - 1) as i32).elem()); | ||
| // Reshape indices for gather operation | ||
| let y0_idx = B::int_reshape(y0_clamped.clone(), Shape::new([n, 1, h_out, w_out, 1])); | ||
| let y0_idx = B::int_expand(y0_idx, Shape::new([n, c, h_out, w_out, w_in])); | ||
| let y1_idx = B::int_reshape(y1_clamped.clone(), Shape::new([n, 1, h_out, w_out, 1])); | ||
| let y1_idx = B::int_expand(y1_idx, Shape::new([n, c, h_out, w_out, w_in])); | ||
| let x0_idx = B::int_reshape(x0_clamped, Shape::new([n, 1, h_out, w_out, 1])); | ||
| let x0_idx = B::int_expand(x0_idx, Shape::new([n, c, h_out, w_out, 1])); | ||
| let x1_idx = B::int_reshape(x1_clamped, Shape::new([n, 1, h_out, w_out, 1])); | ||
| let x1_idx = B::int_expand(x1_idx, Shape::new([n, c, h_out, w_out, 1])); | ||
| // Reshape tensor for gather operation | ||
@@ -96,16 +191,16 @@ let tensor = B::float_reshape(tensor, Shape::new([n, c, h_in, 1, w_in])); | ||
| // Gather on x and y. Watch out for the shapes | ||
| let sample_00 = B::float_gather(2, tensor.clone(), y_indices_low.clone()); | ||
| let sample_00 = B::float_gather(4, sample_00, x_indices_low.clone()); | ||
| // Gather samples from the four corners | ||
| let sample_00 = B::float_gather(2, tensor.clone(), y0_idx.clone()); | ||
| let sample_00 = B::float_gather(4, sample_00, x0_idx.clone()); | ||
| let sample_01 = B::float_gather(2, tensor.clone(), y_indices_high.clone()); | ||
| let sample_01 = B::float_gather(4, sample_01, x_indices_low.clone()); | ||
| let sample_01 = B::float_gather(2, tensor.clone(), y1_idx.clone()); | ||
| let sample_01 = B::float_gather(4, sample_01, x0_idx.clone()); | ||
| let sample_10 = B::float_gather(2, tensor.clone(), y_indices_low.clone()); | ||
| let sample_10 = B::float_gather(4, sample_10, x_indices_high.clone()); | ||
| let sample_10 = B::float_gather(2, tensor.clone(), y0_idx); | ||
| let sample_10 = B::float_gather(4, sample_10, x1_idx.clone()); | ||
| let sample_11 = B::float_gather(2, tensor, y_indices_high); | ||
| let sample_11 = B::float_gather(4, sample_11, x_indices_high); | ||
| let sample_11 = B::float_gather(2, tensor, y1_idx); | ||
| let sample_11 = B::float_gather(4, sample_11, x1_idx); | ||
| // Reshape to (N, C, H_out, W_out) for multiplying with weights | ||
| // Reshape samples to (N, C, H_out, W_out) | ||
| let sample_00 = B::float_reshape(sample_00, Shape::new([n, c, h_out, w_out])); | ||
@@ -116,30 +211,71 @@ let sample_01 = B::float_reshape(sample_01, Shape::new([n, c, h_out, w_out])); | ||
| // Weights for bilinear interp | ||
| let weight_00 = B::float_mul( | ||
| B::float_sub(grid_x_plus_one.clone(), grid_x.clone()), | ||
| B::float_sub(grid_y_plus_one.clone(), grid_y.clone()), | ||
| ); | ||
| let weight_10 = B::float_mul( | ||
| B::float_sub(grid_x.clone(), grid_x_floored.clone()), | ||
| B::float_sub(grid_y_plus_one.clone(), grid_y.clone()), | ||
| ); | ||
| let weight_01 = B::float_mul( | ||
| B::float_sub(grid_x_plus_one.clone(), grid_x.clone()), | ||
| B::float_sub(grid_y.clone(), grid_y_floored.clone()), | ||
| ); | ||
| let weight_11 = B::float_mul( | ||
| B::float_sub(grid_x.clone(), grid_x_floored), | ||
| B::float_sub(grid_y.clone(), grid_y_floored), | ||
| ); | ||
| // Apply masks for zeros padding (set out-of-bounds samples to 0) | ||
| let (sample_00, sample_01, sample_10, sample_11) = | ||
| if padding_mode == GridSamplePaddingMode::Zeros { | ||
| let mask_00 = mask_00.unwrap(); | ||
| let mask_01 = mask_01.unwrap(); | ||
| let mask_10 = mask_10.unwrap(); | ||
| let mask_11 = mask_11.unwrap(); | ||
| // Bilinear interp | ||
| let sample_0 = B::float_add( | ||
| B::float_mul(sample_00, weight_00), | ||
| B::float_mul(sample_01, weight_01), | ||
| ); | ||
| let sample_1 = B::float_add( | ||
| B::float_mul(sample_10, weight_10), | ||
| B::float_mul(sample_11, weight_11), | ||
| ); | ||
| B::float_add(sample_0, sample_1) | ||
| let mask_00_inv = B::bool_not(mask_00); | ||
| let mask_00_inv = B::bool_reshape(mask_00_inv, Shape::new([n, 1, h_out, w_out])); | ||
| let mask_00_inv = B::bool_expand(mask_00_inv, Shape::new([n, c, h_out, w_out])); | ||
| let mask_01_inv = B::bool_not(mask_01); | ||
| let mask_01_inv = B::bool_reshape(mask_01_inv, Shape::new([n, 1, h_out, w_out])); | ||
| let mask_01_inv = B::bool_expand(mask_01_inv, Shape::new([n, c, h_out, w_out])); | ||
| let mask_10_inv = B::bool_not(mask_10); | ||
| let mask_10_inv = B::bool_reshape(mask_10_inv, Shape::new([n, 1, h_out, w_out])); | ||
| let mask_10_inv = B::bool_expand(mask_10_inv, Shape::new([n, c, h_out, w_out])); | ||
| let mask_11_inv = B::bool_not(mask_11); | ||
| let mask_11_inv = B::bool_reshape(mask_11_inv, Shape::new([n, 1, h_out, w_out])); | ||
| let mask_11_inv = B::bool_expand(mask_11_inv, Shape::new([n, c, h_out, w_out])); | ||
| ( | ||
| B::float_mask_fill(sample_00, mask_00_inv, 0.0f32.elem()), | ||
| B::float_mask_fill(sample_01, mask_01_inv, 0.0f32.elem()), | ||
| B::float_mask_fill(sample_10, mask_10_inv, 0.0f32.elem()), | ||
| B::float_mask_fill(sample_11, mask_11_inv, 0.0f32.elem()), | ||
| ) | ||
| } else { | ||
| (sample_00, sample_01, sample_10, sample_11) | ||
| }; | ||
| // Compute bilinear interpolation weights | ||
| let one_minus_x = B::float_neg(x_frac.clone()); | ||
| let one_minus_x = B::float_add_scalar(one_minus_x, 1.0f32.elem()); | ||
| let one_minus_y = B::float_neg(y_frac.clone()); | ||
| let one_minus_y = B::float_add_scalar(one_minus_y, 1.0f32.elem()); | ||
| let weight_00 = B::float_mul(one_minus_x.clone(), one_minus_y.clone()); | ||
| let weight_01 = B::float_mul(one_minus_x.clone(), y_frac.clone()); | ||
| let weight_10 = B::float_mul(x_frac.clone(), one_minus_y); | ||
| let weight_11 = B::float_mul(x_frac, y_frac); | ||
| // Bilinear interpolation | ||
| let result = B::float_mul(sample_00, weight_00); | ||
| let result = B::float_add(result, B::float_mul(sample_01, weight_01)); | ||
| let result = B::float_add(result, B::float_mul(sample_10, weight_10)); | ||
| B::float_add(result, B::float_mul(sample_11, weight_11)) | ||
| } | ||
| /// Reflect coordinates at boundaries for reflection padding. | ||
| /// | ||
| /// Uses the formula: reflected = 2 * bound - x for out-of-bounds coordinates. | ||
| fn reflect_coordinates<B: Backend>(coords: FloatTensor<B>, size: f64) -> FloatTensor<B> { | ||
| // Simple reflection: clamp to [0, size-1] after reflecting | ||
| // For values < 0: reflect at 0 -> -x | ||
| // For values >= size: reflect at size-1 -> 2*(size-1) - x | ||
| // This is a simplified implementation - full reflection would handle multiple reflections | ||
| let max_val = (size - 1.0) as f32; | ||
| // First handle negative values by taking absolute value | ||
| let coords = B::float_abs(coords); | ||
| // Then handle values > max by reflecting: 2*max - x | ||
| // But we need to detect which values need this | ||
| // Simplified: just clamp for now, proper reflection is complex | ||
| B::float_clamp(coords, 0.0f32.elem(), max_val.elem()) | ||
| } |
@@ -0,7 +1,7 @@ | ||
| use crate::quantization::QuantScheme; | ||
| use alloc::vec::Vec; | ||
| use cubecl_quant::scheme::QuantScheme; | ||
| use crate::{ | ||
| Device, Shape, TensorData, TensorMetadata, TensorPrimitive, | ||
| backend::Backend, | ||
| backend::{Backend, ExecutionError}, | ||
| quantization::{ | ||
@@ -192,3 +192,5 @@ Calibration, QTensorPrimitive, QuantPropagation, QuantizationParametersPrimitive, | ||
| /// The data structure with the tensor's data. | ||
| fn q_into_data(tensor: QuantizedTensor<B>) -> impl Future<Output = TensorData> + Send; | ||
| fn q_into_data( | ||
| tensor: QuantizedTensor<B>, | ||
| ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send; | ||
@@ -195,0 +197,0 @@ /// Detaches a tensor from the computation graph. |
+15
-11
| use super::cat::cat_with_slice_assign; | ||
| use super::grid_sample::float_grid_sample_2d_bilinear; | ||
| use super::grid_sample::float_grid_sample_2d_ref; | ||
| use super::repeat_dim::repeat_with_slice_assign; | ||
| use super::{BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor}; | ||
| use crate::ops::InterpolateMode; | ||
| use crate::backend::ExecutionError; | ||
| use crate::ops::GridSampleOptions; | ||
| use crate::{Distribution, ElementConversion, Float, TensorData, backend::Backend, tensor::Shape}; | ||
@@ -102,3 +103,5 @@ use crate::{FloatDType, TensorMetadata, TensorPrimitive}; | ||
| /// The data structure with the tensor's data. | ||
| fn float_into_data(tensor: FloatTensor<B>) -> impl Future<Output = TensorData> + Send; | ||
| fn float_into_data( | ||
| tensor: FloatTensor<B>, | ||
| ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send; | ||
@@ -493,2 +496,7 @@ /// Gets the device of the tensor. | ||
| /// The selected elements in a new tensor. | ||
| /// | ||
| /// # Note | ||
| /// | ||
| /// Empty slices (where start >= end) are handled at the high-level tensor API and will not | ||
| /// be passed to this method. Backend implementations do not need to handle empty slices. | ||
| fn float_slice(tensor: FloatTensor<B>, slices: &[crate::Slice]) -> FloatTensor<B>; | ||
@@ -1520,4 +1528,3 @@ | ||
| /// | ||
| /// Interpolation is bilinear. | ||
| /// Padding is border: out of bounds locations will be clamped to the nearest border | ||
| /// # Arguments | ||
| /// | ||
@@ -1527,3 +1534,3 @@ /// * `tensor` - The tensor being sampled from, shape (N, C, H_in, W_in) | ||
| /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right | ||
| /// * `method` - How to interpolate between samples | ||
| /// * `options` - Grid sampling options (mode, padding_mode, align_corners) | ||
| /// | ||
@@ -1536,8 +1543,5 @@ /// # Returns | ||
| grid: FloatTensor<B>, | ||
| method: InterpolateMode, | ||
| options: GridSampleOptions, | ||
| ) -> FloatTensor<B> { | ||
| match method { | ||
| InterpolateMode::Bilinear => float_grid_sample_2d_bilinear::<B>(tensor, grid), | ||
| _ => todo!("Default implementation for grid_sample_2d with {method:?} unimplemented"), | ||
| } | ||
| float_grid_sample_2d_ref::<B>(tensor, grid, options) | ||
| } | ||
@@ -1544,0 +1548,0 @@ |
@@ -5,3 +5,6 @@ use alloc::vec::Vec; | ||
| use super::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; | ||
| use crate::{TensorData, backend::Backend}; | ||
| use crate::{ | ||
| TensorData, | ||
| backend::{Backend, ExecutionError}, | ||
| }; | ||
@@ -23,3 +26,3 @@ #[derive(Default)] | ||
| /// Contains all [data](TensorData) related to a [transaction](TransactionPrimitive). | ||
| pub struct TransactionPrimitiveResult { | ||
| pub struct TransactionPrimitiveData { | ||
| /// Float tensor data. | ||
@@ -39,6 +42,6 @@ pub read_floats: Vec<TensorData>, | ||
| /// Executes a [transaction](TransactionPrimitive) and return its | ||
| /// [result](TransactionPrimitiveResult). | ||
| /// [data](TransactionPrimitiveData). | ||
| fn tr_execute( | ||
| transaction: TransactionPrimitive<B>, | ||
| ) -> impl Future<Output = TransactionPrimitiveResult> + Send { | ||
| ) -> impl Future<Output = Result<TransactionPrimitiveData, ExecutionError>> + Send { | ||
| async move { | ||
@@ -51,15 +54,15 @@ let mut floats = Vec::new(); | ||
| for t in transaction.read_floats { | ||
| floats.push(B::float_into_data(t).await); | ||
| floats.push(B::float_into_data(t).await?); | ||
| } | ||
| for t in transaction.read_qfloats { | ||
| qfloats.push(B::q_into_data(t).await); | ||
| qfloats.push(B::q_into_data(t).await?); | ||
| } | ||
| for t in transaction.read_ints { | ||
| ints.push(B::int_into_data(t).await); | ||
| ints.push(B::int_into_data(t).await?); | ||
| } | ||
| for t in transaction.read_bools { | ||
| bools.push(B::bool_into_data(t).await); | ||
| bools.push(B::bool_into_data(t).await?); | ||
| } | ||
| TransactionPrimitiveResult { | ||
| Ok(TransactionPrimitiveData { | ||
| read_floats: floats, | ||
@@ -69,5 +72,5 @@ read_qfloats: qfloats, | ||
| read_bools: bools, | ||
| } | ||
| }) | ||
| } | ||
| } | ||
| } |
@@ -1,4 +0,2 @@ | ||
| mod bytes; | ||
| mod calibration; | ||
| mod data; | ||
| mod parameters; | ||
@@ -8,7 +6,6 @@ mod primitive; | ||
| pub use bytes::*; | ||
| pub use burn_std::quantization::QuantizedBytes; | ||
| pub use calibration::*; | ||
| pub use data::*; | ||
| pub use parameters::*; | ||
| pub use primitive::*; | ||
| pub use scheme::*; |
| use crate::{DType, Shape, Tensor, backend::Backend}; | ||
| use alloc::vec::Vec; | ||
| pub use burn_std::quantization::QParams; | ||
| /// The tensor quantization parameters. | ||
| pub type QuantizationParameters<B> = QParams<Tensor<B, 1>>; | ||
| /// The quantization tensor data parameters. | ||
| #[derive(Clone, Debug)] | ||
| pub struct QParams<S> { | ||
| /// The scaling factor. | ||
| pub scales: S, | ||
| } | ||
| /// The quantization parameters primitive. | ||
@@ -15,0 +10,0 @@ /// |
@@ -1,5 +0,3 @@ | ||
| use cubecl_quant::scheme::QuantScheme; | ||
| use crate::quantization::{QuantAcc, QuantPropagation, QuantScheme}; | ||
| use crate::quantization::{QuantAcc, QuantPropagation}; | ||
| /// Quantized tensor primitive. | ||
@@ -6,0 +4,0 @@ pub trait QTensorPrimitive { |
| // We re-export those types. | ||
| pub use cubecl_quant::scheme::{ | ||
| pub use burn_std::quantization::{ | ||
| BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue, | ||
@@ -4,0 +4,0 @@ }; |
+4
-1
@@ -292,2 +292,5 @@ mod activation; | ||
| burn_tensor::testgen_padding!(); | ||
| // test transaction | ||
| burn_tensor::testgen_transaction!(); | ||
| }; | ||
@@ -411,3 +414,3 @@ } | ||
| use cubecl_quant::scheme::QuantLevel; | ||
| use crate::quantization::QuantLevel; | ||
@@ -414,0 +417,0 @@ use crate::{ |
@@ -94,3 +94,3 @@ #[burn_tensor_testgen::testgen(add)] | ||
| TestBackend::sync(&tensor2.device()); | ||
| TestBackend::sync(&tensor2.device()).unwrap(); | ||
@@ -113,3 +113,3 @@ let output = tensor1 + tensor2; | ||
| TestBackend::sync(&tensor2.device()); | ||
| TestBackend::sync(&tensor2.device()).unwrap(); | ||
@@ -116,0 +116,0 @@ let output = tensor2 + tensor1; |
@@ -310,3 +310,3 @@ #[burn_tensor_testgen::testgen(aggregation)] | ||
| let tensor = TestTensorInt::arange(0..9, &Default::default()).float(); | ||
| TestBackend::sync(&tensor.device()); | ||
| TestBackend::sync(&tensor.device()).unwrap(); | ||
@@ -324,3 +324,3 @@ let output = (tensor.reshape([3, 3]) + 2); | ||
| let tensor = tensor.reshape([3, 3]); | ||
| TestBackend::sync(&tensor.device()); | ||
| TestBackend::sync(&tensor.device()).unwrap(); | ||
@@ -337,3 +337,3 @@ let output = (tensor.swap_dims(0, 1) + 2); | ||
| let tensor = TestTensorInt::arange(0..9, &Default::default()).float(); | ||
| TestBackend::sync(&tensor.device()); | ||
| TestBackend::sync(&tensor.device()).unwrap(); | ||
@@ -353,6 +353,6 @@ let output = (tensor.reshape([1, 3, 3]) + 2); | ||
| let tensor_2 = tensor_2.reshape([1, 2, 1]); | ||
| TestBackend::sync(&tensor_1.device()); | ||
| TestBackend::sync(&tensor_1.device()).unwrap(); | ||
| let output = (tensor_1 + tensor_2.clone()).sum_dim(2) + tensor_2; | ||
| TestBackend::sync(&output.device()); | ||
| TestBackend::sync(&output.device()).unwrap(); | ||
| let expected = TensorData::from([[[56.0], [77.0]]]); | ||
@@ -373,3 +373,3 @@ | ||
| let tensor_2 = tensor_2.swap_dims(1, 2); | ||
| TestBackend::sync(&tensor_1.device()); | ||
| TestBackend::sync(&tensor_1.device()).unwrap(); | ||
@@ -394,3 +394,3 @@ let output = (tensor_1 + tensor_2).sum_dim(2); | ||
| let tensor_2 = tensor_2.swap_dims(1, 2); | ||
| TestBackend::sync(&tensor_1.device()); | ||
| TestBackend::sync(&tensor_1.device()).unwrap(); | ||
@@ -415,3 +415,3 @@ let output = (tensor_3 + tensor_1 + tensor_2).sum_dim(2); | ||
| let tensor_2 = tensor_2.swap_dims(1, 2); | ||
| TestBackend::sync(&tensor_1.device()); | ||
| TestBackend::sync(&tensor_1.device()).unwrap(); | ||
@@ -435,3 +435,3 @@ let output = (tensor_3 + tensor_1 + tensor_2).sum_dim(1); | ||
| TestBackend::sync(&tensor_1.device()); | ||
| TestBackend::sync(&tensor_1.device()).unwrap(); | ||
| let sum = tensor_2.clone().sum_dim(0); | ||
@@ -441,3 +441,3 @@ let sum = sum.sum_dim(1); | ||
| TestBackend::sync(&tensor_1.device()); | ||
| TestBackend::sync(&tensor_1.device()).unwrap(); | ||
@@ -462,3 +462,3 @@ let tmp = sum.clone() + 2; | ||
| TestBackend::sync(&tensor_1.device()); | ||
| TestBackend::sync(&tensor_1.device()).unwrap(); | ||
@@ -465,0 +465,0 @@ let reshaped = tensor_1.reshape([1, 4, 4]); |
@@ -6,7 +6,14 @@ #[burn_tensor_testgen::testgen(grid_sample)] | ||
| Tensor, TensorData, Tolerance, | ||
| ops::{FloatElem, InterpolateMode}, | ||
| ops::{FloatElem, GridSampleOptions, GridSamplePaddingMode, InterpolateMode}, | ||
| }; | ||
| /// Tests grid_sample_2d with default options (align_corners=false, zeros padding). | ||
| /// | ||
| /// For a 3x3 input with grid coordinates: | ||
| /// - (0.0, 0.0) maps to pixel (1.0, 1.0) -> center pixel = 4.0 | ||
| /// - (-1.0, 0.25) maps to pixel (-0.5, 1.375) -> partially out of bounds | ||
| /// - (1.0, 1.0) maps to pixel (2.5, 2.5) -> corner, partially out of bounds | ||
| /// - (0.2, -0.8) maps to pixel (1.3, 0.3) -> interpolates around center-top | ||
| #[test] | ||
| fn should_grid_sample_2d() { | ||
| fn should_grid_sample_2d_default() { | ||
| let device = Default::default(); | ||
@@ -22,4 +29,32 @@ let tensor = TestTensor::<4>::from_floats( | ||
| let output = tensor.grid_sample_2d(grid, InterpolateMode::Bilinear); | ||
| let output = tensor.grid_sample_2d(grid, GridSampleOptions::default()); | ||
| // Expected values computed with PyTorch grid_sample(align_corners=False, padding_mode='zeros') | ||
| let expected = TensorData::from([[[[4.0, 2.0625], [2.0, 1.04]]]]); | ||
| output | ||
| .to_data() | ||
| .assert_approx_eq::<FloatElem<TestBackend>>(&expected, Tolerance::default()); | ||
| } | ||
| /// Tests grid_sample_2d with align_corners=true and border padding. | ||
| /// | ||
| /// This is the original Burn semantics before the API change. | ||
| #[test] | ||
| fn should_grid_sample_2d_align_corners_border() { | ||
| let device = Default::default(); | ||
| let tensor = TestTensor::<4>::from_floats( | ||
| [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], | ||
| &device, | ||
| ); | ||
| let grid = TestTensor::<4>::from_floats( | ||
| [[[[0.0, 0.0], [-1.0, 0.25]], [[1.0, 1.0], [0.2, -0.8]]]], | ||
| &device, | ||
| ); | ||
| let options = GridSampleOptions::new(InterpolateMode::Bilinear) | ||
| .with_padding_mode(GridSamplePaddingMode::Border) | ||
| .with_align_corners(true); | ||
| let output = tensor.grid_sample_2d(grid, options); | ||
| // Expected values computed with PyTorch grid_sample(align_corners=True, padding_mode='border') | ||
| let expected = TensorData::from([[[[4.0, 3.75], [8.0, 1.8]]]]); | ||
@@ -31,4 +66,6 @@ output | ||
| /// Tests out-of-bounds grid coordinates with zeros padding. | ||
| /// Grid coordinate (0.0, -2.0) maps to pixel (1.0, -2.5) which is completely out of bounds. | ||
| #[test] | ||
| fn should_pad_border_grid_sample_2d() { | ||
| fn should_pad_zeros_grid_sample_2d() { | ||
| let device = Default::default(); | ||
@@ -41,6 +78,6 @@ let tensor = TestTensor::<4>::from_floats( | ||
| let output = tensor.grid_sample_2d(grid, InterpolateMode::Bilinear); | ||
| let output = tensor.grid_sample_2d(grid, GridSampleOptions::default()); | ||
| // Should clamp to nearest: 1.0 | ||
| let expected = TensorData::from([[[[1.0]]]]); | ||
| // With zeros padding, out-of-bounds samples return 0 | ||
| let expected = TensorData::from([[[[0.0]]]]); | ||
| output | ||
@@ -51,4 +88,5 @@ .to_data() | ||
| /// Tests out-of-bounds grid coordinates with border padding. | ||
| #[test] | ||
| fn should_handle_inf_grid_sample_2d() { | ||
| fn should_pad_border_grid_sample_2d() { | ||
| let device = Default::default(); | ||
@@ -59,8 +97,11 @@ let tensor = TestTensor::<4>::from_floats( | ||
| ); | ||
| let grid = TestTensor::<4>::from_floats([[[[0.0, 1.0f32 / 0.0]]]], &device); | ||
| let grid = TestTensor::<4>::from_floats([[[[0.0, -2.0]]]], &device); | ||
| let output = tensor.grid_sample_2d(grid, InterpolateMode::Bilinear); | ||
| let options = GridSampleOptions::new(InterpolateMode::Bilinear) | ||
| .with_padding_mode(GridSamplePaddingMode::Border); | ||
| let output = tensor.grid_sample_2d(grid, options); | ||
| // Should clamp to nearest, regardless of inf grid value: 7.0 | ||
| let expected = TensorData::from([[[[7.0]]]]); | ||
| // With border padding, out-of-bounds coordinates are clamped to border | ||
| // Grid (0.0, -2.0) with align_corners=false: pixel (1.0, -2.5) -> clamped to (1.0, 0.0) = 1.0 | ||
| let expected = TensorData::from([[[[1.0]]]]); | ||
| output | ||
@@ -67,0 +108,0 @@ .to_data() |
@@ -302,3 +302,3 @@ #[burn_tensor_testgen::testgen(matmul)] | ||
| TestBackend::sync(&device); | ||
| TestBackend::sync(&device).unwrap(); | ||
| let weight = weight.transpose(); | ||
@@ -305,0 +305,0 @@ let out = signal.matmul(weight) + 5; |
@@ -85,2 +85,3 @@ mod abs; | ||
| mod topk; | ||
| mod transaction; | ||
| mod transpose; | ||
@@ -87,0 +88,0 @@ mod tri; |
@@ -421,24 +421,42 @@ #[burn_tensor_testgen::testgen(slice)] | ||
| #[test] | ||
| #[should_panic] | ||
| fn should_panic_when_slice_is_desc() { | ||
| fn should_support_descending_slice_as_empty() { | ||
| // Like PyTorch, x[3:1] should return an empty tensor, not panic | ||
| let data = TensorData::from([0.0, 1.0, 2.0]); | ||
| let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); | ||
| let tensor = TestTensor::<1>::from_data(data, &Default::default()); | ||
| let output = tensor.slice(s![2..1]); | ||
| output.into_data().assert_eq(&data, false); | ||
| // Should produce an empty tensor with shape [0] | ||
| assert_eq!(output.dims(), [0]); | ||
| } | ||
| #[test] | ||
| #[should_panic] | ||
| fn should_panic_when_slice_is_equal() { | ||
| fn should_support_empty_slice() { | ||
| // ONNX models can have empty slices where start == end | ||
| // This should produce a tensor with size 0 in that dimension | ||
| let data = TensorData::from([0.0, 1.0, 2.0]); | ||
| let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default()); | ||
| let tensor = TestTensor::<1>::from_data(data, &Default::default()); | ||
| let output = tensor.slice([1..1]); | ||
| output.into_data().assert_eq(&data, false); | ||
| // Should produce an empty tensor with shape [0] | ||
| assert_eq!(output.dims(), [0]); | ||
| } | ||
| #[test] | ||
| fn should_support_empty_slice_2d() { | ||
| // Test empty slice on 2D tensor | ||
| let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); | ||
| let tensor = TestTensor::<2>::from_data(data, &Default::default()); | ||
| // Empty slice on first dimension | ||
| let output = tensor.clone().slice([1..1, 0..3]); | ||
| assert_eq!(output.dims(), [0, 3]); | ||
| // Empty slice on second dimension | ||
| let output = tensor.slice([0..2, 2..2]); | ||
| assert_eq!(output.dims(), [2, 0]); | ||
| } | ||
| #[test] | ||
| fn test_slice_with_positive_step() { | ||
@@ -445,0 +463,0 @@ let device = Default::default(); |
| use alloc::vec::Vec; | ||
| use crate::Shape; | ||
| use crate::indexing::AsIndex; | ||
| use core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; | ||
| /// Trait for slice arguments that can be converted into an array of slices. | ||
| /// This allows the `slice` method to accept both single slices (from `s![..]`) | ||
| /// and arrays of slices (from `s![.., ..]` or `[0..5, 1..3]`). | ||
| pub trait SliceArg<const D2: usize> { | ||
| /// Convert to an array of slices with clamping to shape dimensions | ||
| fn into_slices(self, shape: Shape) -> [Slice; D2]; | ||
| } | ||
| impl<const D2: usize, T> SliceArg<D2> for [T; D2] | ||
| where | ||
| T: Into<Slice>, | ||
| { | ||
| fn into_slices(self, shape: Shape) -> [Slice; D2] { | ||
| self.into_iter() | ||
| .enumerate() | ||
| .map(|(i, s)| { | ||
| let slice: Slice = s.into(); | ||
| // Apply shape clamping by converting to range and back | ||
| let clamped_range = slice.to_range(shape[i]); | ||
| Slice::new( | ||
| clamped_range.start as isize, | ||
| Some(clamped_range.end as isize), | ||
| slice.step(), | ||
| ) | ||
| }) | ||
| .collect::<Vec<_>>() | ||
| .try_into() | ||
| .unwrap() | ||
| } | ||
| } | ||
| impl<T> SliceArg<1> for T | ||
| where | ||
| T: Into<Slice>, | ||
| { | ||
| fn into_slices(self, shape: Shape) -> [Slice; 1] { | ||
| let slice: Slice = self.into(); | ||
| let clamped_range = slice.to_range(shape[0]); | ||
| [Slice::new( | ||
| clamped_range.start as isize, | ||
| Some(clamped_range.end as isize), | ||
| slice.step(), | ||
| )] | ||
| } | ||
| } | ||
| /// Slice argument constructor for tensor indexing. | ||
| /// | ||
| /// The `s![]` macro is used to create multi-dimensional slice specifications for tensors. | ||
| /// It converts various range syntax forms into a `&[Slice]` that can be used with | ||
| /// `tensor.slice()` and `tensor.slice_assign()` operations. | ||
| /// | ||
| /// # Syntax Overview | ||
| /// | ||
| /// ## Basic Forms | ||
| /// | ||
| /// * **`s![index]`** - Index a single element (produces a subview with that axis removed) | ||
| /// * **`s![range]`** - Slice a range of elements | ||
| /// * **`s![range;step]`** - Slice a range with a custom step | ||
| /// * **`s![dim1, dim2, ...]`** - Multiple dimensions, each can be any of the above forms | ||
| /// | ||
| /// ## Range Types | ||
| /// | ||
| /// All standard Rust range types are supported: | ||
| /// * **`a..b`** - From `a` (inclusive) to `b` (exclusive) | ||
| /// * **`a..=b`** - From `a` to `b` (both inclusive) | ||
| /// * **`a..`** - From `a` to the end | ||
| /// * **`..b`** - From the beginning to `b` (exclusive) | ||
| /// * **`..=b`** - From the beginning to `b` (inclusive) | ||
| /// * **`..`** - The full range (all elements) | ||
| /// | ||
| /// ## Negative Indices | ||
| /// | ||
| /// Negative indices count from the end of the axis: | ||
| /// * **`-1`** refers to the last element | ||
| /// * **`-2`** refers to the second-to-last element | ||
| /// * And so on... | ||
| /// | ||
| /// This works in all range forms: `s![-3..-1]`, `s![-2..]`, `s![..-1]` | ||
| /// | ||
| /// ## Step Syntax | ||
| /// | ||
| /// Steps control the stride between selected elements: | ||
| /// * **`;step`** after a range specifies the step | ||
| /// * **Positive steps** select every nth element going forward | ||
| /// * **Negative steps** select every nth element going backward | ||
| /// * Default step is `1` when not specified | ||
| /// * Step cannot be `0` | ||
| /// | ||
| /// ### Negative Step Behavior | ||
| /// | ||
| /// With negative steps, the range bounds still specify *which* elements to include, | ||
| /// but the traversal order is reversed: | ||
| /// | ||
| /// * `s![0..5;-1]` selects indices `[4, 3, 2, 1, 0]` (not `[0, 1, 2, 3, 4]`) | ||
| /// * `s![2..8;-2]` selects indices `[7, 5, 3]` (starting from 7, going backward by 2) | ||
| /// * `s![..;-1]` reverses the entire axis | ||
| /// | ||
| /// This matches the semantics of NumPy and the ndarray crate. | ||
| /// | ||
| /// # Examples | ||
| /// | ||
| /// ## Basic Slicing | ||
| /// | ||
| /// ```rust,ignore | ||
| /// use burn_tensor::{Tensor, s}; | ||
| /// | ||
| /// # fn example<B: Backend>(tensor: Tensor<B, 3>) { | ||
| /// // Select rows 0-5 (exclusive) | ||
| /// let subset = tensor.slice(s![0..5, .., ..]); | ||
| /// | ||
| /// // Select the last row | ||
| /// let last_row = tensor.slice(s![-1, .., ..]); | ||
| /// | ||
| /// // Select columns 2, 3, 4 | ||
| /// let cols = tensor.slice(s![.., 2..5, ..]); | ||
| /// | ||
| /// // Select a single element at position [1, 2, 3] | ||
| /// let element = tensor.slice(s![1, 2, 3]); | ||
| /// # } | ||
| /// ``` | ||
| /// | ||
| /// ## Slicing with Steps | ||
| /// | ||
| /// ```rust,ignore | ||
| /// use burn_tensor::{Tensor, s}; | ||
| /// | ||
| /// # fn example<B: Backend>(tensor: Tensor<B, 2>) { | ||
| /// // Select every 2nd row | ||
| /// let even_rows = tensor.slice(s![0..10;2, ..]); | ||
| /// | ||
| /// // Select every 3rd column | ||
| /// let cols = tensor.slice(s![.., 0..9;3]); | ||
| /// | ||
| /// // Select every 2nd element in reverse order | ||
| /// let reversed_even = tensor.slice(s![10..0;-2, ..]); | ||
| /// # } | ||
| /// ``` | ||
| /// | ||
| /// ## Reversing Dimensions | ||
| /// | ||
| /// ```rust,ignore | ||
| /// use burn_tensor::{Tensor, s}; | ||
| /// | ||
| /// # fn example<B: Backend>(tensor: Tensor<B, 2>) { | ||
| /// // Reverse the first dimension | ||
| /// let reversed = tensor.slice(s![..;-1, ..]); | ||
| /// | ||
| /// // Reverse both dimensions | ||
| /// let fully_reversed = tensor.slice(s![..;-1, ..;-1]); | ||
| /// | ||
| /// // Reverse a specific range | ||
| /// let range_reversed = tensor.slice(s![2..8;-1, ..]); | ||
| /// # } | ||
| /// ``` | ||
| /// | ||
| /// ## Complex Multi-dimensional Slicing | ||
| /// | ||
| /// ```rust,ignore | ||
| /// use burn_tensor::{Tensor, s}; | ||
| /// | ||
| /// # fn example<B: Backend>(tensor: Tensor<B, 4>) { | ||
| /// // Mix of different slice types | ||
| /// let complex = tensor.slice(s![ | ||
| /// 0..10;2, // Every 2nd element from 0 to 10 | ||
| /// .., // All elements in dimension 1 | ||
| /// 5..15;-3, // Every 3rd element from 14 down to 5 | ||
| /// -1 // Last element in dimension 3 | ||
| /// ]); | ||
| /// | ||
| /// // Using inclusive ranges | ||
| /// let inclusive = tensor.slice(s![2..=5, 1..=3, .., ..]); | ||
| /// | ||
| /// // Negative indices with steps | ||
| /// let from_end = tensor.slice(s![-5..-1;2, .., .., ..]); | ||
| /// # } | ||
| /// ``` | ||
| /// | ||
| /// ## Slice Assignment | ||
| /// | ||
| /// ```rust,ignore | ||
| /// use burn_tensor::{Tensor, s}; | ||
| /// | ||
| /// # fn example<B: Backend>(tensor: Tensor<B, 2>, values: Tensor<B, 2>) { | ||
| /// // Assign to every 2nd row | ||
| /// let tensor = tensor.slice_assign(s![0..10;2, ..], values); | ||
| /// | ||
| /// // Assign to a reversed slice | ||
| /// let tensor = tensor.slice_assign(s![..;-1, 0..5], values); | ||
| /// # } | ||
| /// ``` | ||
| #[macro_export] | ||
| macro_rules! s { | ||
| // Empty - should not happen | ||
| [] => { | ||
| compile_error!("Empty slice specification") | ||
| }; | ||
| // Single expression with step | ||
| [$range:expr; $step:expr] => { | ||
| { | ||
| #[allow(clippy::reversed_empty_ranges)] | ||
| { | ||
| $crate::Slice::from_range_stepped($range, $step) | ||
| } | ||
| } | ||
| }; | ||
| // Single expression without step (no comma after) | ||
| [$range:expr] => { | ||
| { | ||
| #[allow(clippy::reversed_empty_ranges)] | ||
| { | ||
| $crate::Slice::from($range) | ||
| } | ||
| } | ||
| }; | ||
| // Two or more expressions with first having step | ||
| [$range:expr; $step:expr, $($rest:tt)*] => { | ||
| { | ||
| #[allow(clippy::reversed_empty_ranges)] | ||
| { | ||
| $crate::s!(@internal [$crate::Slice::from_range_stepped($range, $step)] $($rest)*) | ||
| } | ||
| } | ||
| }; | ||
| // Two or more expressions with first not having step | ||
| [$range:expr, $($rest:tt)*] => { | ||
| { | ||
| #[allow(clippy::reversed_empty_ranges)] | ||
| { | ||
| $crate::s!(@internal [$crate::Slice::from($range)] $($rest)*) | ||
| } | ||
| } | ||
| }; | ||
| // Internal: finished parsing | ||
| (@internal [$($acc:expr),*]) => { | ||
| [$($acc),*] | ||
| }; | ||
| // Internal: parse range with step followed by comma | ||
| (@internal [$($acc:expr),*] $range:expr; $step:expr, $($rest:tt)*) => { | ||
| $crate::s!(@internal [$($acc,)* $crate::Slice::from_range_stepped($range, $step as isize)] $($rest)*) | ||
| }; | ||
| // Internal: parse range with step at end | ||
| (@internal [$($acc:expr),*] $range:expr; $step:expr) => { | ||
| $crate::s!(@internal [$($acc,)* $crate::Slice::from_range_stepped($range, $step as isize)]) | ||
| }; | ||
| // Internal: parse range without step followed by comma | ||
| (@internal [$($acc:expr),*] $range:expr, $($rest:tt)*) => { | ||
| $crate::s!(@internal [$($acc,)* $crate::Slice::from($range)] $($rest)*) | ||
| }; | ||
| // Internal: parse range without step at end | ||
| (@internal [$($acc:expr),*] $range:expr) => { | ||
| $crate::s!(@internal [$($acc,)* $crate::Slice::from($range)]) | ||
| }; | ||
| } | ||
/// A slice specification for a single tensor dimension.
///
/// This struct represents a range with an optional step, used for advanced indexing
/// operations on tensors. It is typically created using the [`s!`] macro rather than
/// constructed directly.
///
/// # Fields
///
/// * `start` - The starting index (inclusive). Negative values count from the end.
/// * `end` - The ending index (exclusive). `None` means to the end of the dimension.
/// * `step` - The stride between elements. Must be non-zero; constructors such as
///   [`Slice::new`] panic if a zero step is given.
///
/// # Index Interpretation
///
/// - **Positive indices**: Count from the beginning (0-based)
/// - **Negative indices**: Count from the end (-1 is the last element)
/// - **Bounds checking**: Indices are clamped to valid ranges
///
/// # Step Behavior
///
/// - **Positive step**: Traverse forward through the range
/// - **Negative step**: Traverse backward through the range
/// - **Step size**: Determines how many elements to skip
///
/// # Examples
///
/// While you typically use the [`s!`] macro, you can also construct slices directly:
///
/// ```rust,ignore
/// use burn_tensor::Slice;
///
/// // Equivalent to s![2..8]
/// let slice1 = Slice::new(2, Some(8), 1);
///
/// // Equivalent to s![0..10;2]
/// let slice2 = Slice::new(0, Some(10), 2);
///
/// // Equivalent to s![..;-1] (reverse)
/// let slice3 = Slice::new(0, None, -1);
/// ```
///
/// See also the [`s!`] macro for the preferred way to create slices.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct Slice {
    /// Slice start index (inclusive; negative counts from the end).
    pub start: isize,
    /// Slice end index (exclusive). `None` means up to the end of the dimension.
    pub end: Option<isize>,
    /// Step between elements (default: 1). Never zero.
    pub step: isize,
}
| impl Default for Slice { | ||
| fn default() -> Self { | ||
| Self::full() | ||
| } | ||
| } | ||
| impl Slice { | ||
| /// Creates a new slice with start, end, and step | ||
| pub const fn new(start: isize, end: Option<isize>, step: isize) -> Self { | ||
| assert!(step != 0, "Step cannot be zero"); | ||
| Self { start, end, step } | ||
| } | ||
| /// Creates a slice that represents the full range. | ||
| pub const fn full() -> Self { | ||
| Self::new(0, None, 1) | ||
| } | ||
| /// Creates a slice that represents a single index | ||
| pub fn index(idx: isize) -> Self { | ||
| Self { | ||
| start: idx, | ||
| end: handle_signed_inclusive_end(idx), | ||
| step: 1, | ||
| } | ||
| } | ||
| /// Creates a slice with a custom step | ||
| pub fn with_step(start: isize, end: Option<isize>, step: isize) -> Self { | ||
| assert!(step != 0, "Step cannot be zero"); | ||
| Self { start, end, step } | ||
| } | ||
| /// Creates a slice from a range with a specified step | ||
| pub fn from_range_stepped<R: Into<Slice>>(range: R, step: isize) -> Self { | ||
| assert!(step != 0, "Step cannot be zero"); | ||
| let mut slice = range.into(); | ||
| slice.step = step; | ||
| slice | ||
| } | ||
| /// Returns the step of the slice | ||
| pub fn step(&self) -> isize { | ||
| self.step | ||
| } | ||
| /// Returns the range for this slice given a dimension size | ||
| pub fn range(&self, size: usize) -> Range<usize> { | ||
| self.to_range(size) | ||
| } | ||
| /// Convert this slice to a range for a dimension of the given size. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `size` - The size of the dimension to slice. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A `Range<usize>` representing the slice bounds. | ||
| pub fn to_range(&self, size: usize) -> Range<usize> { | ||
| // Always return a valid range with start <= end | ||
| // The step information will be handled separately | ||
| let start = convert_signed_index(self.start, size); | ||
| let end = match self.end { | ||
| Some(end) => convert_signed_index(end, size), | ||
| None => size, | ||
| }; | ||
| start..end | ||
| } | ||
| /// Converts the slice into a range and step tuple | ||
| pub fn to_range_and_step(&self, size: usize) -> (Range<usize>, isize) { | ||
| let range = self.to_range(size); | ||
| (range, self.step) | ||
| } | ||
| /// Returns true if the step is negative | ||
| pub fn is_reversed(&self) -> bool { | ||
| self.step < 0 | ||
| } | ||
| /// Calculates the output size for this slice operation | ||
| pub fn output_size(&self, dim_size: usize) -> usize { | ||
| let range = self.to_range(dim_size); | ||
| let len = range.end - range.start; | ||
| if self.step.unsigned_abs() == 1 { | ||
| len | ||
| } else { | ||
| len.div_ceil(self.step.unsigned_abs()) | ||
| } | ||
| } | ||
| } | ||
/// Resolve a possibly-negative index against a dimension of `size`, clamping
/// the result into `0..=size`.
///
/// Negative indices count from the end; out-of-bounds inputs are clamped
/// rather than wrapping or panicking.
fn convert_signed_index(index: isize, size: usize) -> usize {
    if index >= 0 {
        (index as usize).min(size)
    } else {
        (size as isize + index).max(0) as usize
    }
}
/// Convert an inclusive end index into the exclusive form stored in `Slice`.
///
/// An inclusive end of `-1` means "up to and including the last element",
/// which has no exclusive signed counterpart, so it maps to `None` (open end).
/// Every other value simply becomes `end + 1`.
fn handle_signed_inclusive_end(end: isize) -> Option<isize> {
    if end == -1 { None } else { Some(end + 1) }
}
| impl<I: AsIndex> From<Range<I>> for Slice { | ||
| fn from(r: Range<I>) -> Self { | ||
| Self { | ||
| start: r.start.index(), | ||
| end: Some(r.end.index()), | ||
| step: 1, | ||
| } | ||
| } | ||
| } | ||
| impl<I: AsIndex + Copy> From<RangeInclusive<I>> for Slice { | ||
| fn from(r: RangeInclusive<I>) -> Self { | ||
| Self { | ||
| start: (*r.start()).index(), | ||
| end: handle_signed_inclusive_end((*r.end()).index()), | ||
| step: 1, | ||
| } | ||
| } | ||
| } | ||
| impl<I: AsIndex> From<RangeFrom<I>> for Slice { | ||
| fn from(r: RangeFrom<I>) -> Self { | ||
| Self { | ||
| start: r.start.index(), | ||
| end: None, | ||
| step: 1, | ||
| } | ||
| } | ||
| } | ||
| impl<I: AsIndex> From<RangeTo<I>> for Slice { | ||
| fn from(r: RangeTo<I>) -> Self { | ||
| Self { | ||
| start: 0, | ||
| end: Some(r.end.index()), | ||
| step: 1, | ||
| } | ||
| } | ||
| } | ||
| impl<I: AsIndex> From<RangeToInclusive<I>> for Slice { | ||
| fn from(r: RangeToInclusive<I>) -> Self { | ||
| Self { | ||
| start: 0, | ||
| end: handle_signed_inclusive_end(r.end.index()), | ||
| step: 1, | ||
| } | ||
| } | ||
| } | ||
| impl From<RangeFull> for Slice { | ||
| fn from(_: RangeFull) -> Self { | ||
| Self { | ||
| start: 0, | ||
| end: None, | ||
| step: 1, | ||
| } | ||
| } | ||
| } | ||
| impl From<usize> for Slice { | ||
| fn from(i: usize) -> Self { | ||
| Slice::index(i as isize) | ||
| } | ||
| } | ||
| impl From<isize> for Slice { | ||
| fn from(i: isize) -> Self { | ||
| Slice::index(i) | ||
| } | ||
| } | ||
| impl From<i32> for Slice { | ||
| fn from(i: i32) -> Self { | ||
| Slice::index(i as isize) | ||
| } | ||
| } | ||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slice_output_size() {
        // Forward steps.
        assert_eq!(Slice::new(0, Some(10), 1).output_size(10), 10);
        assert_eq!(Slice::new(0, Some(10), 2).output_size(10), 5);
        assert_eq!(Slice::new(0, Some(10), 3).output_size(10), 4); // ceil(10/3)
        assert_eq!(Slice::new(0, Some(10), 4).output_size(10), 3); // ceil(10/4)
        // A negative step traverses the same range backwards: same count.
        assert_eq!(Slice::new(0, Some(10), -1).output_size(10), 10);
        assert_eq!(Slice::new(0, Some(10), -2).output_size(10), 5);
        assert_eq!(Slice::new(2, Some(8), -3).output_size(10), 2); // ceil(6/3)
        // Degenerate and signed-index cases.
        assert_eq!(Slice::new(5, Some(5), 1).output_size(10), 0); // empty range
        assert_eq!(Slice::new(-3, None, 1).output_size(10), 3); // start counted from end
        assert_eq!(Slice::new(0, Some(-2), 1).output_size(10), 8); // end counted from end
    }
}
-1233
| use core::f32; | ||
| use alloc::boxed::Box; | ||
| use alloc::format; | ||
| use alloc::string::String; | ||
| use alloc::vec::Vec; | ||
| use bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError}; | ||
| use cubecl_quant::scheme::QuantScheme; | ||
| use half::{bf16, f16}; | ||
| use num_traits::{Float, ToPrimitive}; | ||
| use crate::{ | ||
| DType, Distribution, Element, ElementConversion, | ||
| quantization::{QuantValue, QuantizedBytes}, | ||
| tensor::Bytes, | ||
| }; | ||
| use rand::RngCore; | ||
| use super::quantization::{QuantLevel, QuantMode}; | ||
/// Data structure for tensors.
///
/// Stores the raw element bytes together with the shape and the [`DType`]
/// needed to interpret them. Element count is the product of `shape`.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TensorData {
    /// The values of the tensor (as bytes), interpreted according to `dtype`.
    pub bytes: Bytes,
    /// The shape of the tensor.
    pub shape: Vec<usize>,
    /// The data type of the tensor.
    pub dtype: DType,
}
| impl TensorData { | ||
    /// Creates a new tensor data structure.
    ///
    /// # Panics
    ///
    /// Panics if the number of elements in `value` does not match the product
    /// of the dimensions in `shape`.
    pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
        // Ensure shape is valid
        let shape = shape.into();
        Self::check_data_len(&value, &shape);
        Self {
            bytes: Bytes::from_elems(value),
            shape,
            dtype: E::dtype(),
        }
    }
    /// Creates a new quantized tensor data structure.
    ///
    /// # Arguments
    ///
    /// * `value` - Quantized values, one per tensor element.
    /// * `shape` - The tensor shape.
    /// * `scheme` - The quantization scheme describing how values are stored.
    /// * `qparams` - Quantization parameters as `f32` values.
    ///
    /// # Panics
    ///
    /// Panics if the number of elements in `value` does not match `shape`.
    pub fn quantized<E: Element, S: Into<Vec<usize>>>(
        value: Vec<E>,
        shape: S,
        scheme: QuantScheme,
        qparams: &[f32],
    ) -> Self {
        let shape = shape.into();
        Self::check_data_len(&value, &shape);
        // Packs values and qparams into one byte buffer; the dtype records the
        // scheme returned by the packing step.
        let q_bytes = QuantizedBytes::new(value, scheme, qparams);
        Self {
            bytes: q_bytes.bytes,
            shape,
            dtype: DType::QFloat(q_bytes.scheme),
        }
    }
| /// Creates a new tensor data structure from raw bytes. | ||
| pub fn from_bytes<S: Into<Vec<usize>>>(bytes: Bytes, shape: S, dtype: DType) -> Self { | ||
| Self { | ||
| bytes, | ||
| shape: shape.into(), | ||
| dtype, | ||
| } | ||
| } | ||
| /// Creates a new tensor data structure from raw bytes stored in a vector. | ||
| /// | ||
| /// Prefer [`TensorData::new`] or [`TensorData::quantized`] over this method unless you are | ||
| /// certain that the bytes representation is valid. | ||
| pub fn from_bytes_vec<S: Into<Vec<usize>>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self { | ||
| Self { | ||
| bytes: Bytes::from_bytes_vec(bytes), | ||
| shape: shape.into(), | ||
| dtype, | ||
| } | ||
| } | ||
| // Check that the input vector contains a correct number of elements | ||
| fn check_data_len<E: Element>(data: &[E], shape: &Vec<usize>) { | ||
| let expected_data_len = Self::numel(shape); | ||
| let num_data = data.len(); | ||
| assert_eq!( | ||
| expected_data_len, num_data, | ||
| "Shape {shape:?} is invalid for input of size {num_data:?}", | ||
| ); | ||
| } | ||
    /// Returns the immutable slice view of the tensor data.
    ///
    /// # Errors
    ///
    /// Returns [`DataError::TypeMismatch`] if `E` differs from `self.dtype`,
    /// or [`DataError::CastError`] if the bytes cannot be reinterpreted as `E`.
    pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {
        if E::dtype() == self.dtype {
            match E::dtype() {
                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
                DType::Bool => {
                    let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)
                        .map_err(DataError::CastError)?;
                    // SAFETY: this arm only runs when `E::dtype() == DType::Bool`,
                    // i.e. `E` is `bool`, which has the same size/align as `u8`.
                    Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })
                }
                _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),
            }
        } else {
            Err(DataError::TypeMismatch(format!(
                "Invalid target element type (expected {:?}, got {:?})",
                self.dtype,
                E::dtype()
            )))
        }
    }
    /// Returns the mutable slice view of the tensor data.
    ///
    /// # Errors
    ///
    /// Returns [`DataError::TypeMismatch`] if `E` differs from `self.dtype`,
    /// or [`DataError::CastError`] if the bytes cannot be reinterpreted as `E`.
    pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {
        if E::dtype() == self.dtype {
            match E::dtype() {
                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
                DType::Bool => {
                    let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)
                        .map_err(DataError::CastError)?;
                    // SAFETY: this arm only runs when `E::dtype() == DType::Bool`,
                    // i.e. `E` is `bool`, which has the same size/align as `u8`.
                    Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })
                }
                _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)
                    .map_err(DataError::CastError),
            }
        } else {
            Err(DataError::TypeMismatch(format!(
                "Invalid target element type (expected {:?}, got {:?})",
                self.dtype,
                E::dtype()
            )))
        }
    }
| /// Returns the tensor data as a vector of scalar values. | ||
| pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> { | ||
| Ok(self.as_slice()?.to_vec()) | ||
| } | ||
    /// Returns the tensor data as a vector of scalar values, consuming `self`.
    ///
    /// # Errors
    ///
    /// Returns [`DataError::TypeMismatch`] if `E` does not match `self.dtype`
    /// (consequently, quantized data cannot be extracted through this method).
    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
        // This means we cannot call `into_vec` for QFloat
        if E::dtype() != self.dtype {
            return Err(DataError::TypeMismatch(format!(
                "Invalid target element type (expected {:?}, got {:?})",
                self.dtype,
                E::dtype()
            )));
        }
        match E::dtype() {
            // The only way to create a bool `TensorData` with invalid values is by unsafely modifying
            // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool
            // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.
            DType::Bool => {
                let vec = self.into_vec_unchecked::<u8>()?;
                // SAFETY: `E::dtype() == DType::Bool` here, i.e. `E` is `bool`,
                // which has the same size/align as `u8`.
                Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })
            }
            _ => self.into_vec_unchecked(),
        }
    }
    /// Returns the tensor data as a vector of scalar values. Does not check dtype.
    fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {
        let mut me = self;
        // Fast path: reuse the existing allocation when the byte buffer can be
        // reinterpreted as a `Vec<E>` directly.
        me.bytes = match me.bytes.try_into_vec::<E>() {
            Ok(elems) => return Ok(elems),
            Err(bytes) => bytes,
        };
        // The bytes might have been deserialized and allocated with a different align.
        // In that case, we have to memcopy the data into a new vector, more suitably allocated
        Ok(bytemuck::checked::try_cast_slice(me.as_bytes())
            .map_err(DataError::CastError)?
            .to_vec())
    }
    /// Returns an iterator over the values of the tensor data.
    ///
    /// When `E` matches `self.dtype`, values are yielded directly; otherwise
    /// every stored value is converted to `E` on the fly. Quantized data is
    /// first unpacked to `i8` values, then converted.
    pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
        if E::dtype() == self.dtype {
            Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())
        } else {
            match self.dtype {
                DType::I8 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i8| e.elem::<E>()),
                ),
                DType::I16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i16| e.elem::<E>()),
                ),
                DType::I32 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i32| e.elem::<E>()),
                ),
                DType::I64 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &i64| e.elem::<E>()),
                ),
                // u8 is the storage type itself, so no cast is needed.
                DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
                DType::U16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &u16| e.elem::<E>()),
                ),
                DType::U32 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &u32| e.elem::<E>()),
                ),
                DType::U64 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &u64| e.elem::<E>()),
                ),
                DType::BF16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &bf16| e.elem::<E>()),
                ),
                DType::F16 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &f16| e.elem::<E>()),
                ),
                DType::F32 | DType::Flex32 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &f32| e.elem::<E>()),
                ),
                DType::F64 => Box::new(
                    bytemuck::checked::cast_slice(&self.bytes)
                        .iter()
                        .map(|e: &f64| e.elem::<E>()),
                ),
                // bool is a byte value equal to either 0 or 1
                DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),
                DType::QFloat(scheme) => match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value:
                            QuantValue::Q8F
                            | QuantValue::Q8S
                            // Represent sub-byte values as i8
                            | QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S,
                        ..
                    } => {
                        // Quantized int8 values
                        let q_bytes = QuantizedBytes {
                            bytes: self.bytes.clone(),
                            scheme,
                            num_elements: self.num_elements(),
                        };
                        // Unpacking produces an owned vector, so the iterator
                        // is materialized eagerly before being returned.
                        let (values, _) = q_bytes.into_vec_i8();
                        Box::new(
                            values
                                .iter()
                                .map(|e: &i8| e.elem::<E>())
                                .collect::<Vec<_>>()
                                .into_iter(),
                        )
                    }
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value:
                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
                        ..
                    } => {
                        unimplemented!("Not yet implemented for iteration");
                    }
                },
            }
        }
    }
    /// Returns the rank (the number of dimensions).
    ///
    /// This is simply the length of [`TensorData::shape`].
    pub fn rank(&self) -> usize {
        self.shape.len()
    }
    /// Returns the total number of elements of the tensor data.
    ///
    /// Computed as the product of all dimensions in the shape.
    pub fn num_elements(&self) -> usize {
        Self::numel(&self.shape)
    }
| fn numel(shape: &[usize]) -> usize { | ||
| shape.iter().product() | ||
| } | ||
| /// Populates the data with random values. | ||
| pub fn random<E: Element, R: RngCore, S: Into<Vec<usize>>>( | ||
| shape: S, | ||
| distribution: Distribution, | ||
| rng: &mut R, | ||
| ) -> Self { | ||
| let shape = shape.into(); | ||
| let num_elements = Self::numel(&shape); | ||
| let mut data = Vec::with_capacity(num_elements); | ||
| for _ in 0..num_elements { | ||
| data.push(E::random(distribution, rng)); | ||
| } | ||
| TensorData::new(data, shape) | ||
| } | ||
| /// Populates the data with zeros. | ||
| pub fn zeros<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData { | ||
| let shape = shape.into(); | ||
| let num_elements = Self::numel(&shape); | ||
| let mut data = Vec::<E>::with_capacity(num_elements); | ||
| for _ in 0..num_elements { | ||
| data.push(0.elem()); | ||
| } | ||
| TensorData::new(data, shape) | ||
| } | ||
| /// Populates the data with ones. | ||
| pub fn ones<E: Element, S: Into<Vec<usize>>>(shape: S) -> TensorData { | ||
| let shape = shape.into(); | ||
| let num_elements = Self::numel(&shape); | ||
| let mut data = Vec::<E>::with_capacity(num_elements); | ||
| for _ in 0..num_elements { | ||
| data.push(1.elem()); | ||
| } | ||
| TensorData::new(data, shape) | ||
| } | ||
| /// Populates the data with the given value | ||
| pub fn full<E: Element, S: Into<Vec<usize>>>(shape: S, fill_value: E) -> TensorData { | ||
| let shape = shape.into(); | ||
| let num_elements = Self::numel(&shape); | ||
| let mut data = Vec::<E>::with_capacity(num_elements); | ||
| for _ in 0..num_elements { | ||
| data.push(fill_value) | ||
| } | ||
| TensorData::new(data, shape) | ||
| } | ||
    /// Creates data filled with `fill_value`, converted and stored as the
    /// runtime `dtype` rather than `E`.
    pub(crate) fn full_dtype<E: Element, S: Into<Vec<usize>>>(
        shape: S,
        fill_value: E,
        dtype: DType,
    ) -> TensorData {
        match dtype {
            DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),
            DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),
            DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),
            DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),
            DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),
            DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),
            DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),
            DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),
            DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),
            DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),
            DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),
            DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),
            DType::Bool => Self::full::<bool, _>(shape, fill_value.elem()),
            // NOTE(review): callers are expected never to request a quantized
            // fill from a plain scalar — confirm at call sites.
            DType::QFloat(_) => unreachable!(),
        }
    }
    /// Converts the data to a different element type.
    ///
    /// Shorthand for [`TensorData::convert_dtype`] with `E::dtype()`.
    pub fn convert<E: Element>(self) -> Self {
        self.convert_dtype(E::dtype())
    }
    /// Converts the data to a different element type.
    ///
    /// Three paths: a no-op when the dtype already matches; an in-place
    /// element-wise rewrite when both dtypes are plain numerics of the same
    /// byte size (reuses the existing buffer); otherwise a converting copy
    /// into a freshly allocated buffer.
    pub fn convert_dtype(self, dtype: DType) -> Self {
        if dtype == self.dtype {
            self
        } else if dtype.size() == self.dtype.size()
            && !matches!(self.dtype, DType::Bool | DType::QFloat(_))
            && !matches!(dtype, DType::Bool | DType::QFloat(_))
        {
            match self.dtype {
                DType::F64 => self.convert_inplace_dtype::<f64>(dtype),
                DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),
                DType::F16 => self.convert_inplace_dtype::<f16>(dtype),
                DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),
                DType::I64 => self.convert_inplace_dtype::<i64>(dtype),
                DType::I32 => self.convert_inplace_dtype::<i32>(dtype),
                DType::I16 => self.convert_inplace_dtype::<i16>(dtype),
                DType::I8 => self.convert_inplace_dtype::<i8>(dtype),
                DType::U64 => self.convert_inplace_dtype::<u64>(dtype),
                DType::U32 => self.convert_inplace_dtype::<u32>(dtype),
                DType::U16 => self.convert_inplace_dtype::<u16>(dtype),
                DType::U8 => self.convert_inplace_dtype::<u8>(dtype),
                // Excluded by the `matches!` guards above.
                DType::Bool | DType::QFloat(_) => unreachable!(),
            }
        } else {
            match self.dtype {
                DType::F64 => self.convert_clone_dtype::<f64>(dtype),
                DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),
                DType::F16 => self.convert_clone_dtype::<f16>(dtype),
                DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),
                DType::I64 => self.convert_clone_dtype::<i64>(dtype),
                DType::I32 => self.convert_clone_dtype::<i32>(dtype),
                DType::I16 => self.convert_clone_dtype::<i16>(dtype),
                DType::I8 => self.convert_clone_dtype::<i8>(dtype),
                DType::U64 => self.convert_clone_dtype::<u64>(dtype),
                DType::U32 => self.convert_clone_dtype::<u32>(dtype),
                DType::U16 => self.convert_clone_dtype::<u16>(dtype),
                DType::U8 => self.convert_clone_dtype::<u8>(dtype),
                DType::Bool => self.convert_clone_dtype::<bool>(dtype),
                DType::QFloat(_) => unreachable!(),
            }
        }
    }
    /// Dispatches the in-place conversion on the target `dtype`.
    ///
    /// The caller (`convert_dtype`) guarantees `dtype` has the same byte size
    /// as `Current` and that neither side is `Bool`/`QFloat`.
    fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {
        match dtype {
            DType::F64 => self.convert_inplace::<Current, f64>(),
            DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),
            DType::F16 => self.convert_inplace::<Current, f16>(),
            DType::BF16 => self.convert_inplace::<Current, bf16>(),
            DType::I64 => self.convert_inplace::<Current, i64>(),
            DType::I32 => self.convert_inplace::<Current, i32>(),
            DType::I16 => self.convert_inplace::<Current, i16>(),
            DType::I8 => self.convert_inplace::<Current, i8>(),
            DType::U64 => self.convert_inplace::<Current, u64>(),
            DType::U32 => self.convert_inplace::<Current, u32>(),
            DType::U16 => self.convert_inplace::<Current, u16>(),
            DType::U8 => self.convert_inplace::<Current, u8>(),
            DType::Bool | DType::QFloat(_) => unreachable!(),
        }
    }
    /// Converts each element from `Current` to `Target` inside the existing
    /// byte buffer.
    ///
    /// Relies on `Current` and `Target` having the same byte size (checked by
    /// the caller), so each converted value can be written back over the
    /// original element.
    fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(
        mut self,
    ) -> Self {
        for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {
            // Convert, then reinterpret the slot as `Target` and overwrite it.
            let t: Target = x.elem();
            let x = cast_mut::<_, Target>(x);
            *x = t;
        }
        self.dtype = Target::dtype();
        self
    }
    /// Dispatches the copying conversion on the target `dtype`.
    ///
    /// Used when the element sizes differ (or `bool` is involved), so a new
    /// buffer must be allocated.
    fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {
        match dtype {
            DType::F64 => self.convert_clone::<Current, f64>(),
            DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),
            DType::F16 => self.convert_clone::<Current, f16>(),
            DType::BF16 => self.convert_clone::<Current, bf16>(),
            DType::I64 => self.convert_clone::<Current, i64>(),
            DType::I32 => self.convert_clone::<Current, i32>(),
            DType::I16 => self.convert_clone::<Current, i16>(),
            DType::I8 => self.convert_clone::<Current, i8>(),
            DType::U64 => self.convert_clone::<Current, u64>(),
            DType::U32 => self.convert_clone::<Current, u32>(),
            DType::U16 => self.convert_clone::<Current, u16>(),
            DType::U8 => self.convert_clone::<Current, u8>(),
            DType::Bool => self.convert_clone::<Current, bool>(),
            DType::QFloat(_) => unreachable!(),
        }
    }
    /// Element-wise conversion from `Current` to `Target` into a newly
    /// allocated buffer.
    fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(
        self,
    ) -> Self {
        let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);
        // Pre-zeroed output sized from the shape, filled by zipping over the
        // source elements.
        let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];
        for (x, out) in this.iter().zip(&mut out) {
            *out = x.elem();
        }
        Self::new(out, self.shape)
    }
    /// Returns the data as a slice of bytes.
    ///
    /// The bytes are interpreted according to [`TensorData::dtype`].
    pub fn as_bytes(&self) -> &[u8] {
        &self.bytes
    }
    /// Returns the bytes representation of the data, consuming `self`.
    pub fn into_bytes(self) -> Bytes {
        self.bytes
    }
    /// Asserts the data is equal to another data.
    ///
    /// # Arguments
    ///
    /// * `other` - The other data.
    /// * `strict` - If true, the data types must be the same.
    ///   Otherwise, the comparison is done in the current data type.
    ///
    /// # Panics
    ///
    /// Panics if the data is not equal.
    #[track_caller]
    pub fn assert_eq(&self, other: &Self, strict: bool) {
        if strict {
            assert_eq!(
                self.dtype, other.dtype,
                "Data types differ ({:?} != {:?})",
                self.dtype, other.dtype
            );
        }
        // Dispatch on the stored dtype so values are compared in their native
        // representation.
        match self.dtype {
            DType::F64 => self.assert_eq_elem::<f64>(other),
            DType::F32 | DType::Flex32 => self.assert_eq_elem::<f32>(other),
            DType::F16 => self.assert_eq_elem::<f16>(other),
            DType::BF16 => self.assert_eq_elem::<bf16>(other),
            DType::I64 => self.assert_eq_elem::<i64>(other),
            DType::I32 => self.assert_eq_elem::<i32>(other),
            DType::I16 => self.assert_eq_elem::<i16>(other),
            DType::I8 => self.assert_eq_elem::<i8>(other),
            DType::U64 => self.assert_eq_elem::<u64>(other),
            DType::U32 => self.assert_eq_elem::<u32>(other),
            DType::U16 => self.assert_eq_elem::<u16>(other),
            DType::U8 => self.assert_eq_elem::<u8>(other),
            DType::Bool => self.assert_eq_elem::<bool>(other),
            DType::QFloat(q) => {
                // Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality
                let q_other = if let DType::QFloat(q_other) = other.dtype {
                    q_other
                } else {
                    panic!("Quantized data differs from other not quantized data")
                };
                // Data equality mostly depends on input quantization type, but we also check level
                if q.value == q_other.value && q.level == q_other.level {
                    self.assert_eq_elem::<i8>(other)
                } else {
                    panic!("Quantization schemes differ ({q:?} != {q_other:?})")
                }
            }
        }
    }
| #[track_caller] | ||
| fn assert_eq_elem<E: Element>(&self, other: &Self) { | ||
| let mut message = String::new(); | ||
| if self.shape != other.shape { | ||
| message += format!( | ||
| "\n => Shape is different: {:?} != {:?}", | ||
| self.shape, other.shape | ||
| ) | ||
| .as_str(); | ||
| } | ||
| let mut num_diff = 0; | ||
| let max_num_diff = 5; | ||
| for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() { | ||
| if a.cmp(&b).is_ne() { | ||
| // Only print the first 5 different values. | ||
| if num_diff < max_num_diff { | ||
| message += format!("\n => Position {i}: {a} != {b}").as_str(); | ||
| } | ||
| num_diff += 1; | ||
| } | ||
| } | ||
| if num_diff >= max_num_diff { | ||
| message += format!("\n{} more errors...", num_diff - max_num_diff).as_str(); | ||
| } | ||
| if !message.is_empty() { | ||
| panic!("Tensors are not eq:{message}"); | ||
| } | ||
| } | ||
| /// Asserts the data is approximately equal to another data. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `other` - The other data. | ||
| /// * `tolerance` - The tolerance of the comparison. | ||
| /// | ||
| /// # Panics | ||
| /// | ||
| /// Panics if the data is not approximately equal. | ||
| #[track_caller] | ||
| pub fn assert_approx_eq<F: Float + Element>(&self, other: &Self, tolerance: Tolerance<F>) { | ||
| let mut message = String::new(); | ||
| if self.shape != other.shape { | ||
| message += format!( | ||
| "\n => Shape is different: {:?} != {:?}", | ||
| self.shape, other.shape | ||
| ) | ||
| .as_str(); | ||
| } | ||
| let iter = self.iter::<F>().zip(other.iter::<F>()); | ||
| let mut num_diff = 0; | ||
| let max_num_diff = 5; | ||
| for (i, (a, b)) in iter.enumerate() { | ||
| //if they are both nan, then they are equally nan | ||
| let both_nan = a.is_nan() && b.is_nan(); | ||
| //this works for both infinities | ||
| let both_inf = | ||
| a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero())); | ||
| if both_nan || both_inf { | ||
| continue; | ||
| } | ||
| if !tolerance.approx_eq(F::from(a).unwrap(), F::from(b).unwrap()) { | ||
| // Only print the first 5 different values. | ||
| if num_diff < max_num_diff { | ||
| let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap(); | ||
| let max = F::max(a.abs(), b.abs()); | ||
| let diff_rel = diff_abs / ToPrimitive::to_f64(&max).unwrap(); | ||
| let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap(); | ||
| let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap(); | ||
| message += format!( | ||
| "\n => Position {i}: {a} != {b}\n diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})" | ||
| ) | ||
| .as_str(); | ||
| } | ||
| num_diff += 1; | ||
| } | ||
| } | ||
| if num_diff >= max_num_diff { | ||
| message += format!("\n{} more errors...", num_diff - 5).as_str(); | ||
| } | ||
| if !message.is_empty() { | ||
| panic!("Tensors are not approx eq:{message}"); | ||
| } | ||
| } | ||
    /// Asserts each value is within a given range.
    ///
    /// # Arguments
    ///
    /// * `range` - The range.
    ///
    /// # Panics
    ///
    /// If any value is not within the half-open range bounded inclusively below
    /// and exclusively above (`start..end`).
    pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
        // Element-defined ordering (`Element::cmp`) is used so this also works
        // for float element types.
        for elem in self.iter::<E>() {
            if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() {
                panic!("Element ({elem:?}) is not within range {range:?}");
            }
        }
    }
    /// Asserts each value is within a given inclusive range.
    ///
    /// # Arguments
    ///
    /// * `range` - The range.
    ///
    /// # Panics
    ///
    /// If any value is not within the closed range bounded inclusively on both
    /// ends (`start..=end`).
    pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
        let start = range.start();
        let end = range.end();
        for elem in self.iter::<E>() {
            if elem.cmp(start).is_lt() || elem.cmp(end).is_gt() {
                panic!("Element ({elem:?}) is not within range {range:?}");
            }
        }
    }
| } | ||
| impl<E: Element, const A: usize> From<[E; A]> for TensorData { | ||
| fn from(elems: [E; A]) -> Self { | ||
| TensorData::new(elems.to_vec(), [A]) | ||
| } | ||
| } | ||
| impl<const A: usize> From<[usize; A]> for TensorData { | ||
| fn from(elems: [usize; A]) -> Self { | ||
| TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A]) | ||
| } | ||
| } | ||
| impl From<&[usize]> for TensorData { | ||
| fn from(elems: &[usize]) -> Self { | ||
| let mut data = Vec::with_capacity(elems.len()); | ||
| for elem in elems.iter() { | ||
| data.push(*elem as i64); | ||
| } | ||
| TensorData::new(data, [elems.len()]) | ||
| } | ||
| } | ||
| impl<E: Element> From<&[E]> for TensorData { | ||
| fn from(elems: &[E]) -> Self { | ||
| let mut data = Vec::with_capacity(elems.len()); | ||
| for elem in elems.iter() { | ||
| data.push(*elem); | ||
| } | ||
| TensorData::new(data, [elems.len()]) | ||
| } | ||
| } | ||
| impl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData { | ||
| fn from(elems: [[E; B]; A]) -> Self { | ||
| let mut data = Vec::with_capacity(A * B); | ||
| for elem in elems.into_iter().take(A) { | ||
| for elem in elem.into_iter().take(B) { | ||
| data.push(elem); | ||
| } | ||
| } | ||
| TensorData::new(data, [A, B]) | ||
| } | ||
| } | ||
| impl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]> | ||
| for TensorData | ||
| { | ||
| fn from(elems: [[[E; C]; B]; A]) -> Self { | ||
| let mut data = Vec::with_capacity(A * B * C); | ||
| for elem in elems.into_iter().take(A) { | ||
| for elem in elem.into_iter().take(B) { | ||
| for elem in elem.into_iter().take(C) { | ||
| data.push(elem); | ||
| } | ||
| } | ||
| } | ||
| TensorData::new(data, [A, B, C]) | ||
| } | ||
| } | ||
| impl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize> | ||
| From<[[[[E; D]; C]; B]; A]> for TensorData | ||
| { | ||
| fn from(elems: [[[[E; D]; C]; B]; A]) -> Self { | ||
| let mut data = Vec::with_capacity(A * B * C * D); | ||
| for elem in elems.into_iter().take(A) { | ||
| for elem in elem.into_iter().take(B) { | ||
| for elem in elem.into_iter().take(C) { | ||
| for elem in elem.into_iter().take(D) { | ||
| data.push(elem); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| TensorData::new(data, [A, B, C, D]) | ||
| } | ||
| } | ||
| impl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize> | ||
| From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData | ||
| { | ||
| fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self { | ||
| let mut data = Vec::with_capacity(A * B * C * D * E); | ||
| for elem in elems.into_iter().take(A) { | ||
| for elem in elem.into_iter().take(B) { | ||
| for elem in elem.into_iter().take(C) { | ||
| for elem in elem.into_iter().take(D) { | ||
| for elem in elem.into_iter().take(E) { | ||
| data.push(elem); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| TensorData::new(data, [A, B, C, D, E]) | ||
| } | ||
| } | ||
impl core::fmt::Display for TensorData {
    /// Formats the tensor values as a flat debug-printed list, dispatching on `dtype`
    /// to reinterpret the raw bytes as the matching element type.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let fmt = match self.dtype {
            DType::F64 => format!("{:?}", self.as_slice::<f64>().unwrap()),
            // Flex32 is stored as f32, so it reuses the f32 formatting path.
            DType::F32 | DType::Flex32 => format!("{:?}", self.as_slice::<f32>().unwrap()),
            DType::F16 => format!("{:?}", self.as_slice::<f16>().unwrap()),
            DType::BF16 => format!("{:?}", self.as_slice::<bf16>().unwrap()),
            DType::I64 => format!("{:?}", self.as_slice::<i64>().unwrap()),
            DType::I32 => format!("{:?}", self.as_slice::<i32>().unwrap()),
            DType::I16 => format!("{:?}", self.as_slice::<i16>().unwrap()),
            DType::I8 => format!("{:?}", self.as_slice::<i8>().unwrap()),
            DType::U64 => format!("{:?}", self.as_slice::<u64>().unwrap()),
            DType::U32 => format!("{:?}", self.as_slice::<u32>().unwrap()),
            DType::U16 => format!("{:?}", self.as_slice::<u16>().unwrap()),
            DType::U8 => format!("{:?}", self.as_slice::<u8>().unwrap()),
            DType::Bool => format!("{:?}", self.as_slice::<bool>().unwrap()),
            DType::QFloat(scheme) => match scheme {
                QuantScheme {
                    level: QuantLevel::Tensor | QuantLevel::Block(_),
                    mode: QuantMode::Symmetric,
                    value:
                        QuantValue::Q8F
                        | QuantValue::Q8S
                        // Display sub-byte values as i8
                        | QuantValue::Q4F
                        | QuantValue::Q4S
                        | QuantValue::Q2F
                        | QuantValue::Q2S,
                    ..
                } => {
                    // Integer quant values print as i8 followed by the scheme details.
                    format!("{:?} {scheme:?}", self.iter::<i8>().collect::<Vec<_>>())
                },
                QuantScheme {
                    level: QuantLevel::Tensor | QuantLevel::Block(_),
                    mode: QuantMode::Symmetric,
                    value:
                        QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,
                    ..
                } => {
                    // Float-encoded quant values have no formatting path yet.
                    unimplemented!("Can't format yet");
                }
            },
        };
        f.write_str(fmt.as_str())
    }
}
/// The things that can go wrong when manipulating tensor data.
#[derive(Debug)]
pub enum DataError {
    /// Failed to cast the values to a specified element type.
    CastError(CheckedCastError),
    /// Invalid target element type; the message names the mismatched types.
    TypeMismatch(String),
}
// Marker impl: `Debug` + `Display` are provided, no source/backtrace to expose.
impl core::error::Error for DataError {}
| impl core::fmt::Display for DataError { | ||
| fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { | ||
| f.write_str(format!("{self:?}").as_str()) | ||
| } | ||
| } | ||
/// The tolerance used to compare two floating point numbers.
///
/// Generally, two numbers `x` and `y` are approximately equal if
///
/// ```text
/// |x - y| < max(R * max(|x|, |y|), A)
/// ```
///
/// where `R` is the relative tolerance and `A` is the absolute tolerance.
///
///
/// The most common way to initialize this struct is to use `Tolerance::<F>::default()`,
/// which uses the balanced preset (see `Tolerance::balanced`).
///
/// Another common initialization is `Tolerance::<F>::rel_abs(1e-4, 1e-5).set_half_precision_relative(1e-2)`.
/// This will use a sane default to manage values too close to 0.0 and
/// use different relative tolerances depending on the floating point precision.
#[derive(Debug, Clone, Copy)]
pub struct Tolerance<F> {
    // Relative tolerance `R`, scaled by the magnitude of the compared values.
    relative: F,
    // Absolute tolerance `A`, a floor for comparisons near zero.
    absolute: F,
}
| impl<F: Float> Default for Tolerance<F> { | ||
| fn default() -> Self { | ||
| Self::balanced() | ||
| } | ||
| } | ||
impl<F: Float> Tolerance<F> {
    /// Create a tolerance with strict precision setting.
    ///
    /// No relative slack; the absolute tolerance is a small multiple of the
    /// smallest positive value representable by `F`.
    pub fn strict() -> Self {
        Self {
            relative: F::from(0.00).unwrap(),
            absolute: F::from(64).unwrap() * F::min_positive_value(),
        }
    }
    /// Create a tolerance with balanced precision setting.
    pub fn balanced() -> Self {
        Self {
            relative: F::from(0.005).unwrap(), // 0.5%
            absolute: F::from(1e-5).unwrap(),
        }
    }
    /// Create a tolerance with permissive precision setting.
    pub fn permissive() -> Self {
        Self {
            relative: F::from(0.01).unwrap(), // 1.0%
            absolute: F::from(0.01).unwrap(),
        }
    }
    /// When comparing two numbers, this uses both the relative and absolute differences.
    ///
    /// That is, `x` and `y` are approximately equal if
    ///
    /// ```text
    /// |x - y| < max(R * max(|x|, |y|), A)
    /// ```
    ///
    /// where `R` is the `relative` tolerance and `A` is the `absolute` tolerance.
    pub fn rel_abs<FF: ToPrimitive>(relative: FF, absolute: FF) -> Self {
        let relative = Self::check_relative(relative);
        let absolute = Self::check_absolute(absolute);
        Self { relative, absolute }
    }
    /// When comparing two numbers, this uses only the relative difference.
    ///
    /// That is, `x` and `y` are approximately equal if
    ///
    /// ```text
    /// |x - y| < R * max(|x|, |y|)
    /// ```
    ///
    /// where `R` is the relative `tolerance`.
    pub fn relative<FF: ToPrimitive>(tolerance: FF) -> Self {
        let relative = Self::check_relative(tolerance);
        Self {
            relative,
            absolute: F::from(0.0).unwrap(),
        }
    }
    /// When comparing two numbers, this uses only the absolute difference.
    ///
    /// That is, `x` and `y` are approximately equal if
    ///
    /// ```text
    /// |x - y| < A
    /// ```
    ///
    /// where `A` is the absolute `tolerance`.
    pub fn absolute<FF: ToPrimitive>(tolerance: FF) -> Self {
        let absolute = Self::check_absolute(tolerance);
        Self {
            relative: F::from(0.0).unwrap(),
            absolute,
        }
    }
    /// Change the relative tolerance to the given one.
    pub fn set_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        self.relative = Self::check_relative(tolerance);
        self
    }
    /// Change the relative tolerance to the given one only if `F` is half precision.
    ///
    /// Precision is detected from the byte width of `F` (2 bytes = half).
    pub fn set_half_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if core::mem::size_of::<F>() == 2 {
            self.relative = Self::check_relative(tolerance);
        }
        self
    }
    /// Change the relative tolerance to the given one only if `F` is single precision.
    ///
    /// Precision is detected from the byte width of `F` (4 bytes = single).
    pub fn set_single_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if core::mem::size_of::<F>() == 4 {
            self.relative = Self::check_relative(tolerance);
        }
        self
    }
    /// Change the relative tolerance to the given one only if `F` is double precision.
    ///
    /// Precision is detected from the byte width of `F` (8 bytes = double).
    pub fn set_double_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if core::mem::size_of::<F>() == 8 {
            self.relative = Self::check_relative(tolerance);
        }
        self
    }
    /// Change the absolute tolerance to the given one.
    pub fn set_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        self.absolute = Self::check_absolute(tolerance);
        self
    }
    /// Change the absolute tolerance to the given one only if `F` is half precision.
    pub fn set_half_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if core::mem::size_of::<F>() == 2 {
            self.absolute = Self::check_absolute(tolerance);
        }
        self
    }
    /// Change the absolute tolerance to the given one only if `F` is single precision.
    pub fn set_single_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if core::mem::size_of::<F>() == 4 {
            self.absolute = Self::check_absolute(tolerance);
        }
        self
    }
    /// Change the absolute tolerance to the given one only if `F` is double precision.
    pub fn set_double_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
        if core::mem::size_of::<F>() == 8 {
            self.absolute = Self::check_absolute(tolerance);
        }
        self
    }
    /// Checks if `x` and `y` are approximately equal given the tolerance.
    pub fn approx_eq(&self, x: F, y: F) -> bool {
        // See the accepted answer here
        // https://stackoverflow.com/questions/4915462/how-should-i-do-floating-point-comparison
        // This also handles the case where both a and b are infinity so that we don't need
        // to manage it in the rest of the function.
        if x == y {
            return true;
        }
        let diff = (x - y).abs();
        let max = F::max(x.abs(), y.abs());
        // |x - y| < max(A, R * max(|x|, |y|))
        diff < self.absolute.max(self.relative * max)
    }
    // Validates and converts a relative tolerance; must not exceed 1 (100%).
    fn check_relative<FF: ToPrimitive>(tolerance: FF) -> F {
        let tolerance = F::from(tolerance).unwrap();
        assert!(tolerance <= F::one());
        tolerance
    }
    // Validates and converts an absolute tolerance; must be non-negative.
    fn check_absolute<FF: ToPrimitive>(tolerance: FF) -> F {
        let tolerance = F::from(tolerance).unwrap();
        assert!(tolerance >= F::zero());
        tolerance
    }
}
#[cfg(test)]
mod tests {
    use crate::Shape;
    use super::*;
    use alloc::vec;
    use rand::{SeedableRng, rngs::StdRng};
    // Rank is derived from the shape the data was created with.
    #[test]
    fn should_have_rank() {
        let shape = Shape::new([3, 5, 6]);
        let data = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut StdRng::from_os_rng(),
        );
        assert_eq!(data.rank(), 3);
    }
    // `into_vec` must yield the exact same element sequence as `iter`.
    #[test]
    fn into_vec_should_yield_same_value_as_iter() {
        let shape = Shape::new([3, 5, 6]);
        let data = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut StdRng::from_os_rng(),
        );
        let expected = data.iter::<f32>().collect::<Vec<f32>>();
        let actual = data.into_vec::<f32>().unwrap();
        assert_eq!(expected, actual);
    }
    // Requesting a mismatching element type must fail, not silently reinterpret.
    #[test]
    #[should_panic]
    fn into_vec_should_assert_wrong_dtype() {
        let shape = Shape::new([3, 5, 6]);
        let data = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut StdRng::from_os_rng(),
        );
        data.into_vec::<i32>().unwrap();
    }
    #[test]
    fn should_have_right_num_elements() {
        let shape = Shape::new([3, 5, 6]);
        let num_elements = shape.num_elements();
        let data = TensorData::random::<f32, _, _>(
            shape,
            Distribution::Default,
            &mut StdRng::from_os_rng(),
        );
        assert_eq!(num_elements, data.bytes.len() / 4); // f32 stored as u8s
        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());
    }
    // Nested-array `From` impls must produce matching shapes.
    #[test]
    fn should_have_right_shape() {
        let data = TensorData::from([[3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![1, 3]);
        let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);
        assert_eq!(data.shape, vec![2, 3]);
        let data = TensorData::from([3.0, 5.0, 6.0]);
        assert_eq!(data.shape, vec![3]);
    }
    // A 0.03 delta passes with an absolute tolerance of 3e-2.
    #[test]
    fn should_assert_appox_eq_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.03, 5.0, 6.0]]);
        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(3e-2));
        data1.assert_approx_eq::<half::f16>(&data2, Tolerance::absolute(3e-2));
    }
    // A delta just above the absolute tolerance must panic.
    #[test]
    #[should_panic]
    fn should_assert_approx_eq_above_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.031, 5.0, 6.0]]);
        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }
    // Mismatched shapes must panic regardless of values.
    #[test]
    #[should_panic]
    fn should_assert_approx_eq_check_shape() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let data2 = TensorData::from([[3.0, 5.0, 6.0]]);
        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }
    // Byte storage must preserve both length and capacity of the source vec.
    #[test]
    fn should_convert_bytes_correctly() {
        let mut vector: Vec<f32> = Vec::with_capacity(5);
        vector.push(2.0);
        vector.push(3.0);
        let data1 = TensorData::new(vector, vec![2]);
        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();
        assert_eq!(data1.bytes.len(), 2 * factor);
        assert_eq!(data1.bytes.capacity(), 5 * factor);
    }
    // Round-trip through `convert` must preserve the values for several dtypes.
    #[test]
    fn should_convert_bytes_correctly_inplace() {
        fn test_precision<E: Element>() {
            let data = TensorData::new((0..32).collect(), [32]);
            for (i, val) in data
                .clone()
                .convert::<E>()
                .into_vec::<E>()
                .unwrap()
                .into_iter()
                .enumerate()
            {
                assert_eq!(i as u32, val.elem::<u32>())
            }
        }
        test_precision::<f32>();
        test_precision::<f16>();
        test_precision::<i64>();
        test_precision::<i32>();
    }
    // Generates one test per dtype checking `full_dtype` agrees with typed `full`.
    macro_rules! test_dtypes {
        ($test_name:ident, $($dtype:ty),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<$test_name _ $dtype:snake>]() {
                        let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());
                        let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());
                        assert_eq!(full_dtype, full);
                    }
                }
            )*
        };
    }
    test_dtypes!(
        should_create_with_dtype,
        bool,
        i8,
        i16,
        i32,
        i64,
        u8,
        u16,
        u32,
        u64,
        f16,
        bf16,
        f32,
        f64
    );
}
| use rand::{Rng, RngCore, distr::StandardUniform}; | ||
| use crate::{Element, ElementConversion}; | ||
/// Distribution for random value of a tensor.
#[derive(Debug, Default, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Distribution {
    /// Uniform distribution from 0 (inclusive) to 1 (exclusive).
    #[default]
    Default,
    /// Bernoulli distribution with the given probability of success.
    Bernoulli(f64),
    /// Uniform distribution over the half-open interval `[low, high)`.
    Uniform(f64, f64),
    /// Normal distribution with the given mean and standard deviation.
    Normal(f64, f64),
}
/// Distribution sampler for random value of a tensor.
///
/// Pairs a [`DistributionSamplerKind`] with a mutably-borrowed RNG
/// so repeated calls to `sample` advance the same generator.
#[derive(new)]
pub struct DistributionSampler<'a, E, R>
where
    StandardUniform: rand::distr::Distribution<E>,
    E: rand::distr::uniform::SampleUniform,
    R: RngCore,
{
    // The concrete distribution to draw from.
    kind: DistributionSamplerKind<E>,
    // Borrowed RNG; lives as long as the sampler.
    rng: &'a mut R,
}
/// Distribution sampler kind for random value of a tensor.
///
/// Mirrors the variants of [`Distribution`], holding the constructed
/// `rand`/`rand_distr` distribution objects.
pub enum DistributionSamplerKind<E>
where
    StandardUniform: rand::distr::Distribution<E>,
    E: rand::distr::uniform::SampleUniform,
{
    /// Standard distribution.
    Standard(rand::distr::StandardUniform),
    /// Uniform distribution.
    Uniform(rand::distr::Uniform<E>),
    /// Bernoulli distribution.
    Bernoulli(rand::distr::Bernoulli),
    /// Normal distribution (sampled as `f64`, then converted to `E`).
    Normal(rand_distr::Normal<f64>),
}
impl<E, R> DistributionSampler<'_, E, R>
where
    StandardUniform: rand::distr::Distribution<E>,
    E: rand::distr::uniform::SampleUniform,
    E: Element,
    R: RngCore,
{
    /// Samples a random value from the distribution.
    pub fn sample(&mut self) -> E {
        match &self.kind {
            DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
            DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),
            DistributionSamplerKind::Bernoulli(distribution) => {
                // Map the boolean draw onto the element domain: success -> 1, failure -> 0.
                if self.rng.sample(distribution) {
                    1.elem()
                } else {
                    0.elem()
                }
            }
            // Normal samples are drawn as f64 and converted to the element type.
            DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(),
        }
    }
}
| impl Distribution { | ||
| /// Creates a new distribution sampler. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `rng` - The random number generator. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// The distribution sampler. | ||
| pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> | ||
| where | ||
| R: RngCore, | ||
| E: Element + rand::distr::uniform::SampleUniform, | ||
| StandardUniform: rand::distr::Distribution<E>, | ||
| { | ||
| let kind = match self { | ||
| Distribution::Default => { | ||
| DistributionSamplerKind::Standard(rand::distr::StandardUniform {}) | ||
| } | ||
| Distribution::Uniform(low, high) => DistributionSamplerKind::Uniform( | ||
| rand::distr::Uniform::new(low.elem::<E>(), high.elem::<E>()).unwrap(), | ||
| ), | ||
| Distribution::Bernoulli(prob) => { | ||
| DistributionSamplerKind::Bernoulli(rand::distr::Bernoulli::new(prob).unwrap()) | ||
| } | ||
| Distribution::Normal(mean, std) => { | ||
| DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap()) | ||
| } | ||
| }; | ||
| DistributionSampler::new(kind, rng) | ||
| } | ||
| } | ||
#[cfg(test)]
mod tests {
    use super::*;
    // `#[default]` on the `Default` variant must drive both `Default::default()`
    // and the inherent `Distribution::default()`.
    #[test]
    fn test_distribution_default() {
        let dist: Distribution = Default::default();
        assert_eq!(dist, Distribution::Default);
        assert_eq!(Distribution::default(), Distribution::Default);
    }
}
| use core::cmp::Ordering; | ||
| use crate::{Distribution, cast::ToElement, quantization::QuantScheme}; | ||
| #[cfg(feature = "cubecl")] | ||
| use cubecl::flex32; | ||
| use cubecl_quant::scheme::{QuantStore, QuantValue}; | ||
| use half::{bf16, f16}; | ||
| use rand::RngCore; | ||
| use serde::{Deserialize, Serialize}; | ||
/// Element trait for tensor.
///
/// Bundles the conversion, random-generation, comparison, and limit traits
/// together with the bytemuck bounds needed to reinterpret raw tensor bytes.
pub trait Element:
    ToElement
    + ElementRandom
    + ElementConversion
    + ElementComparison
    + ElementLimits
    + bytemuck::CheckedBitPattern
    + bytemuck::NoUninit
    + bytemuck::Zeroable
    + core::fmt::Debug
    + core::fmt::Display
    + Default
    + Send
    + Sync
    + Copy
    + 'static
{
    /// The dtype of the element.
    fn dtype() -> DType;
}
/// Element conversion trait for tensor.
pub trait ElementConversion {
    /// Converts an element to another element.
    ///
    /// # Arguments
    ///
    /// * `elem` - The element to convert.
    ///
    /// # Returns
    ///
    /// The converted element.
    fn from_elem<E: ToElement>(elem: E) -> Self;
    /// Converts `self` into the target element type and returns the converted element.
    fn elem<E: Element>(self) -> E;
}
/// Element trait for random value of a tensor.
pub trait ElementRandom {
    /// Returns a random value for the given distribution.
    ///
    /// # Arguments
    ///
    /// * `distribution` - The distribution to sample from.
    /// * `rng` - The random number generator.
    ///
    /// # Returns
    ///
    /// The random value.
    fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
}
/// Element ordering trait.
pub trait ElementComparison {
    /// Returns an [Ordering] between `self` and `other`.
    fn cmp(&self, other: &Self) -> Ordering;
}
/// Element limits trait, exposing the representable value range of a type.
pub trait ElementLimits {
    /// The minimum representable value
    const MIN: Self;
    /// The maximum representable value
    const MAX: Self;
}
/// Macro to implement the element trait for a type.
///
/// Implements `Element`, `ElementConversion`, `ElementRandom`,
/// `ElementComparison`, and `ElementLimits` from the supplied closures
/// and expressions.
#[macro_export]
macro_rules! make_element {
    // Short form: limits default to the type's own `MIN`/`MAX` constants.
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
    ) => {
        make_element!(ty $type, convert $convert, random $random, cmp $cmp, dtype $dtype, min $type::MIN, max $type::MAX);
    };
    // Full form: explicit `min`/`max` for types without MIN/MAX consts (e.g. bool).
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr,
        min $min:expr,
        max $max:expr
    ) => {
        impl Element for $type {
            #[inline(always)]
            fn dtype() -> $crate::DType {
                $dtype
            }
        }
        impl ElementConversion for $type {
            #[inline(always)]
            fn from_elem<E: ToElement>(elem: E) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }
            #[inline(always)]
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }
        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }
        impl ElementComparison for $type {
            fn cmp(&self, other: &Self) -> Ordering {
                // NOTE(review): `elem::<$type>()` on a value already of `$type`
                // looks like an identity conversion — presumably kept for
                // uniformity; confirm before simplifying.
                let a = self.elem::<$type>();
                let b = other.elem::<$type>();
                #[allow(clippy::redundant_closure_call)]
                $cmp(&a, &b)
            }
        }
        impl ElementLimits for $type {
            const MIN: Self = $min;
            const MAX: Self = $max;
        }
    };
}
// Element implementations for the primitive float types: floats use
// `total_cmp` for a total order that handles NaN deterministically.
make_element!(
    ty f64,
    convert ToElement::to_f64,
    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
    cmp |a: &f64, b: &f64| a.total_cmp(b),
    dtype DType::F64
);
make_element!(
    ty f32,
    convert ToElement::to_f32,
    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
    cmp |a: &f32, b: &f32| a.total_cmp(b),
    dtype DType::F32
);
// Integer types: ordinary `Ord` comparison.
make_element!(
    ty i64,
    convert ToElement::to_i64,
    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
    cmp |a: &i64, b: &i64| Ord::cmp(a, b),
    dtype DType::I64
);
make_element!(
    ty u64,
    convert ToElement::to_u64,
    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
    cmp |a: &u64, b: &u64| Ord::cmp(a, b),
    dtype DType::U64
);
make_element!(
    ty i32,
    convert ToElement::to_i32,
    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
    cmp |a: &i32, b: &i32| Ord::cmp(a, b),
    dtype DType::I32
);
make_element!(
    ty u32,
    convert ToElement::to_u32,
    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
    cmp |a: &u32, b: &u32| Ord::cmp(a, b),
    dtype DType::U32
);
make_element!(
    ty i16,
    convert ToElement::to_i16,
    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
    cmp |a: &i16, b: &i16| Ord::cmp(a, b),
    dtype DType::I16
);
make_element!(
    ty u16,
    convert ToElement::to_u16,
    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
    cmp |a: &u16, b: &u16| Ord::cmp(a, b),
    dtype DType::U16
);
make_element!(
    ty i8,
    convert ToElement::to_i8,
    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
    cmp |a: &i8, b: &i8| Ord::cmp(a, b),
    dtype DType::I8
);
make_element!(
    ty u8,
    convert ToElement::to_u8,
    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
    cmp |a: &u8, b: &u8| Ord::cmp(a, b),
    dtype DType::U8
);
// Half-precision types: sampled in f32 then narrowed via `from_elem`.
make_element!(
    ty f16,
    convert ToElement::to_f16,
    random |distribution: Distribution, rng: &mut R| {
        let sample: f32 = distribution.sampler(rng).sample();
        f16::from_elem(sample)
    },
    cmp |a: &f16, b: &f16| a.total_cmp(b),
    dtype DType::F16
);
make_element!(
    ty bf16,
    convert ToElement::to_bf16,
    random |distribution: Distribution, rng: &mut R| {
        let sample: f32 = distribution.sampler(rng).sample();
        bf16::from_elem(sample)
    },
    cmp |a: &bf16, b: &bf16| a.total_cmp(b),
    dtype DType::BF16
);
// flex32 (cubecl only): stored as f32 but with half-precision limits.
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32,
    convert |elem: &dyn ToElement| flex32::from_f32(elem.to_f32()),
    random |distribution: Distribution, rng: &mut R| {
        let sample: f32 = distribution.sampler(rng).sample();
        flex32::from_elem(sample)
    },
    cmp |a: &flex32, b: &flex32| a.total_cmp(b),
    dtype DType::Flex32,
    min flex32::from_f32(half::f16::MIN.to_f32_const()),
    max flex32::from_f32(half::f16::MAX.to_f32_const())
);
// bool: sampled as u8 then converted; limits are false/true explicitly
// since bool has no MIN/MAX constants.
make_element!(
    ty bool,
    convert ToElement::to_bool,
    random |distribution: Distribution, rng: &mut R| {
        let sample: u8 = distribution.sampler(rng).sample();
        bool::from_elem(sample)
    },
    cmp |a: &bool, b: &bool| Ord::cmp(a, b),
    dtype DType::Bool,
    min false,
    max true
);
/// Runtime data type tag for tensor elements.
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum DType {
    F64,
    F32,
    /// Relaxed-precision float stored as f32 (cubecl).
    Flex32,
    F16,
    BF16,
    I64,
    I32,
    I16,
    I8,
    U64,
    U32,
    U16,
    U8,
    Bool,
    /// Quantized float described by the given quantization scheme.
    QFloat(QuantScheme),
}
#[cfg(feature = "cubecl")]
impl From<cubecl::ir::ElemType> for DType {
    /// Maps a cubecl IR element type onto a tensor [`DType`].
    ///
    /// Panics for IR-only types (e.g. TF32) that have no tensor counterpart.
    fn from(value: cubecl::ir::ElemType) -> Self {
        match value {
            cubecl::ir::ElemType::Float(float_kind) => match float_kind {
                cubecl::ir::FloatKind::F16 => DType::F16,
                cubecl::ir::FloatKind::BF16 => DType::BF16,
                cubecl::ir::FloatKind::Flex32 => DType::Flex32,
                cubecl::ir::FloatKind::F32 => DType::F32,
                cubecl::ir::FloatKind::F64 => DType::F64,
                cubecl::ir::FloatKind::TF32 => panic!("Not a valid DType for tensors."),
                // Micro-float formats are reserved for quantization support.
                cubecl::ir::FloatKind::E2M1
                | cubecl::ir::FloatKind::E2M3
                | cubecl::ir::FloatKind::E3M2
                | cubecl::ir::FloatKind::E4M3
                | cubecl::ir::FloatKind::E5M2
                | cubecl::ir::FloatKind::UE8M0 => {
                    unimplemented!("Not yet supported, will be used for quantization")
                }
            },
            cubecl::ir::ElemType::Int(int_kind) => match int_kind {
                cubecl::ir::IntKind::I8 => DType::I8,
                cubecl::ir::IntKind::I16 => DType::I16,
                cubecl::ir::IntKind::I32 => DType::I32,
                cubecl::ir::IntKind::I64 => DType::I64,
            },
            cubecl::ir::ElemType::UInt(uint_kind) => match uint_kind {
                cubecl::ir::UIntKind::U8 => DType::U8,
                cubecl::ir::UIntKind::U16 => DType::U16,
                cubecl::ir::UIntKind::U32 => DType::U32,
                cubecl::ir::UIntKind::U64 => DType::U64,
            },
            _ => panic!("Not a valid DType for tensors."),
        }
    }
}
impl DType {
    /// Returns the size of a type in bytes.
    ///
    /// For quantized sub-byte values with native storage this returns 0,
    /// since their per-element size is fractional.
    pub const fn size(&self) -> usize {
        match self {
            DType::F64 => core::mem::size_of::<f64>(),
            DType::F32 => core::mem::size_of::<f32>(),
            // Flex32 is stored as a full f32.
            DType::Flex32 => core::mem::size_of::<f32>(),
            DType::F16 => core::mem::size_of::<f16>(),
            DType::BF16 => core::mem::size_of::<bf16>(),
            DType::I64 => core::mem::size_of::<i64>(),
            DType::I32 => core::mem::size_of::<i32>(),
            DType::I16 => core::mem::size_of::<i16>(),
            DType::I8 => core::mem::size_of::<i8>(),
            DType::U64 => core::mem::size_of::<u64>(),
            DType::U32 => core::mem::size_of::<u32>(),
            DType::U16 => core::mem::size_of::<u16>(),
            DType::U8 => core::mem::size_of::<u8>(),
            DType::Bool => core::mem::size_of::<bool>(),
            DType::QFloat(scheme) => match scheme.store {
                QuantStore::Native => match scheme.value {
                    QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
                    // e2m1 native is automatically packed by the kernels, so the actual storage is
                    // 8 bits wide.
                    QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
                        core::mem::size_of::<u8>()
                    }
                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                        // Sub-byte values have fractional size
                        0
                    }
                },
                // Packed storage: multiple values share one u32 word.
                QuantStore::U32 => core::mem::size_of::<u32>(),
            },
        }
    }
    /// Returns true if the data type is a floating point type.
    pub fn is_float(&self) -> bool {
        matches!(
            self,
            DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
        )
    }
    /// Returns true if the data type is a signed integer type.
    pub fn is_int(&self) -> bool {
        matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
    }
    /// Returns true if the data type is an unsigned integer type.
    pub fn is_uint(&self) -> bool {
        matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
    }
    /// Returns true if the data type is a boolean type
    pub fn is_bool(&self) -> bool {
        matches!(self, DType::Bool)
    }
    /// Returns the data type name.
    pub fn name(&self) -> &'static str {
        match self {
            DType::F64 => "f64",
            DType::F32 => "f32",
            DType::Flex32 => "flex32",
            DType::F16 => "f16",
            DType::BF16 => "bf16",
            DType::I64 => "i64",
            DType::I32 => "i32",
            DType::I16 => "i16",
            DType::I8 => "i8",
            DType::U64 => "u64",
            DType::U32 => "u32",
            DType::U16 => "u16",
            DType::U8 => "u8",
            DType::Bool => "bool",
            // The scheme details are intentionally not encoded in the name.
            DType::QFloat(_) => "qfloat",
        }
    }
}
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
/// The floating-point subset of [`DType`]; converts to/from `DType` below.
pub enum FloatDType {
    F64,
    F32,
    Flex32,
    F16,
    BF16,
}
impl From<DType> for FloatDType {
    /// Narrows a [`DType`] to its float-only counterpart.
    ///
    /// Panics if `value` is not one of the floating point variants.
    fn from(value: DType) -> Self {
        match value {
            DType::F64 => FloatDType::F64,
            DType::F32 => FloatDType::F32,
            DType::Flex32 => FloatDType::Flex32,
            DType::F16 => FloatDType::F16,
            DType::BF16 => FloatDType::BF16,
            _ => panic!("Expected float data type, got {value:?}"),
        }
    }
}
impl From<FloatDType> for DType {
    /// Widens a [`FloatDType`] back into the full [`DType`] enum.
    /// This direction is total, so it never panics.
    fn from(value: FloatDType) -> Self {
        match value {
            FloatDType::F64 => DType::F64,
            FloatDType::F32 => DType::F32,
            FloatDType::Flex32 => DType::Flex32,
            FloatDType::F16 => DType::F16,
            FloatDType::BF16 => DType::BF16,
        }
    }
}
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
/// The integer (signed and unsigned) subset of [`DType`].
pub enum IntDType {
    I64,
    I32,
    I16,
    I8,
    U64,
    U32,
    U16,
    U8,
}
impl From<DType> for IntDType {
    /// Narrows a [`DType`] to its integer-only counterpart.
    ///
    /// Panics if `value` is not one of the signed or unsigned integer variants.
    fn from(value: DType) -> Self {
        match value {
            DType::I64 => IntDType::I64,
            DType::I32 => IntDType::I32,
            DType::I16 => IntDType::I16,
            DType::I8 => IntDType::I8,
            DType::U64 => IntDType::U64,
            DType::U32 => IntDType::U32,
            DType::U16 => IntDType::U16,
            DType::U8 => IntDType::U8,
            _ => panic!("Expected int data type, got {value:?}"),
        }
    }
}
impl From<IntDType> for DType {
    /// Widens an [`IntDType`] back into the full [`DType`] enum.
    /// This direction is total, so it never panics.
    fn from(value: IntDType) -> Self {
        match value {
            IntDType::I64 => DType::I64,
            IntDType::I32 => DType::I32,
            IntDType::I16 => DType::I16,
            IntDType::I8 => DType::I8,
            IntDType::U64 => DType::U64,
            IntDType::U32 => DType::U32,
            IntDType::U16 => DType::U16,
            IntDType::U8 => DType::U8,
        }
    }
}
| use core::mem::size_of; | ||
| use half::{bf16, f16}; | ||
/// A generic trait for converting a value to a number.
/// Adapted from [num_traits::ToPrimitive] to support [bool].
///
/// A value can be represented by the target type when it lies within
/// the range of scalars supported by the target type.
/// For example, a negative integer cannot be represented by an unsigned
/// integer type, and an `i64` with a very high magnitude might not be
/// convertible to an `i32`.
/// On the other hand, conversions with possible precision loss or truncation
/// are admitted, like an `f32` with a decimal part to an integer type, or
/// even a large `f64` saturating to `f32` infinity.
///
/// The methods *panic* when the value cannot be represented by the target type.
///
/// Only `to_i64` and `to_u64` are required; every other method has a default
/// implementation that routes through one of those two.
pub trait ToElement {
    /// Converts the value of `self` to an `isize`.
    ///
    /// The default implementation converts through `to_i64()`.
    #[inline]
    fn to_isize(&self) -> isize {
        ToElement::to_isize(&self.to_i64())
    }
    /// Converts the value of `self` to an `i8`.
    ///
    /// The default implementation converts through `to_i64()`.
    #[inline]
    fn to_i8(&self) -> i8 {
        ToElement::to_i8(&self.to_i64())
    }
    /// Converts the value of `self` to an `i16`.
    ///
    /// The default implementation converts through `to_i64()`.
    #[inline]
    fn to_i16(&self) -> i16 {
        ToElement::to_i16(&self.to_i64())
    }
    /// Converts the value of `self` to an `i32`.
    ///
    /// The default implementation converts through `to_i64()`.
    #[inline]
    fn to_i32(&self) -> i32 {
        ToElement::to_i32(&self.to_i64())
    }
    /// Converts the value of `self` to an `i64`.
    fn to_i64(&self) -> i64;
    /// Converts the value of `self` to an `i128`.
    ///
    /// The default implementation converts through `to_i64()`. Types implementing
    /// this trait should override this method if they can represent a greater range.
    #[inline]
    fn to_i128(&self) -> i128 {
        i128::from(self.to_i64())
    }
    /// Converts the value of `self` to a `usize`.
    ///
    /// The default implementation converts through `to_u64()`.
    #[inline]
    fn to_usize(&self) -> usize {
        ToElement::to_usize(&self.to_u64())
    }
    /// Converts the value of `self` to a `u8`.
    ///
    /// The default implementation converts through `to_u64()`.
    #[inline]
    fn to_u8(&self) -> u8 {
        ToElement::to_u8(&self.to_u64())
    }
    /// Converts the value of `self` to a `u16`.
    ///
    /// The default implementation converts through `to_u64()`.
    #[inline]
    fn to_u16(&self) -> u16 {
        ToElement::to_u16(&self.to_u64())
    }
    /// Converts the value of `self` to a `u32`.
    ///
    /// The default implementation converts through `to_u64()`.
    #[inline]
    fn to_u32(&self) -> u32 {
        ToElement::to_u32(&self.to_u64())
    }
    /// Converts the value of `self` to a `u64`.
    fn to_u64(&self) -> u64;
    /// Converts the value of `self` to a `u128`.
    ///
    /// The default implementation converts through `to_u64()`. Types implementing
    /// this trait should override this method if they can represent a greater range.
    #[inline]
    fn to_u128(&self) -> u128 {
        u128::from(self.to_u64())
    }
    /// Converts the value of `self` to an `f16`. Overflows may map to positive
    /// or negative infinity.
    #[inline]
    fn to_f16(&self) -> f16 {
        f16::from_f32(self.to_f32())
    }
    /// Converts the value of `self` to an `bf16`. Overflows may map to positive
    /// or negative infinity.
    #[inline]
    fn to_bf16(&self) -> bf16 {
        bf16::from_f32(self.to_f32())
    }
    /// Converts the value of `self` to an `f32`. Overflows may map to positive
    /// or negative infinity.
    ///
    /// The default implementation converts through `to_f64()`.
    #[inline]
    fn to_f32(&self) -> f32 {
        ToElement::to_f32(&self.to_f64())
    }
    /// Converts the value of `self` to an `f64`. Overflows may map to positive
    /// or negative infinity.
    ///
    /// The default implementation converts through `to_u64()`, which panics for
    /// negative values; types with a signed or floating range must override
    /// this method (the integer and float impls in this module do).
    #[inline]
    fn to_f64(&self) -> f64 {
        ToElement::to_f64(&self.to_u64())
    }
    /// Converts the value of `self` to a bool.
    /// Rust only considers 0 and 1 to be valid booleans, but for compatibility, C semantics are
    /// adopted (anything that's not 0 is true).
    ///
    /// The default implementation converts through `to_u64()`, which panics for
    /// negative values; types with a signed or floating range must override
    /// this method (the integer and float impls in this module do).
    #[inline]
    fn to_bool(&self) -> bool {
        ToElement::to_bool(&self.to_u64())
    }
}
// Generates signed-int -> signed-int conversion methods: the cast is performed
// directly when the destination is at least as wide as the source, otherwise
// the value is range-checked first and the method panics on overflow.
macro_rules! impl_to_element_int_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            // Widening signed -> signed casts are always lossless, so the
            // bounds check is only needed for narrowing casts.
            let min = $DstT::MIN as $SrcT;
            let max = $DstT::MAX as $SrcT;
            if size_of::<$SrcT>() <= size_of::<$DstT>() || (min <= *self && *self <= max) {
                *self as $DstT
            } else {
                panic!(
                    "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                    core::any::type_name::<$SrcT>(),
                    self,
                    core::any::type_name::<$DstT>(),
                )
            }
        }
    )*}
}
// Generates signed-int -> unsigned-int conversion methods: negative values
// always panic; non-negative values are cast directly when the destination is
// at least as wide, otherwise checked against the destination's maximum.
macro_rules! impl_to_element_int_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let max = $DstT::MAX as $SrcT;
            if 0 <= *self && (size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max) {
                *self as $DstT
            } else {
                panic!(
                    "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                    core::any::type_name::<$SrcT>(),
                    self,
                    core::any::type_name::<$DstT>(),
                )
            }
        }
    )*}
}
// Builds the full `ToElement` impl for a signed integer type by combining the
// checked int->int / int->uint conversion macros with the trivial float and
// bool conversions.
macro_rules! impl_to_element_int {
    ($T:ident) => {
        impl ToElement for $T {
            impl_to_element_int_to_int! { $T:
                fn to_isize -> isize;
                fn to_i8 -> i8;
                fn to_i16 -> i16;
                fn to_i32 -> i32;
                fn to_i64 -> i64;
                fn to_i128 -> i128;
            }
            impl_to_element_int_to_uint! { $T:
                fn to_usize -> usize;
                fn to_u8 -> u8;
                fn to_u16 -> u16;
                fn to_u32 -> u32;
                fn to_u64 -> u64;
                fn to_u128 -> u128;
            }
            // int -> float may lose precision but never fails, matching the
            // trait's documented contract.
            #[inline]
            fn to_f32(&self) -> f32 {
                *self as f32
            }
            #[inline]
            fn to_f64(&self) -> f64 {
                *self as f64
            }
            // C semantics: any non-zero value is `true`.
            #[inline]
            fn to_bool(&self) -> bool {
                *self != 0
            }
        }
    };
}
impl_to_element_int!(isize);
impl_to_element_int!(i8);
impl_to_element_int!(i16);
impl_to_element_int!(i32);
impl_to_element_int!(i64);
impl_to_element_int!(i128);
// Generates unsigned-int -> signed-int conversion methods. Note the strict
// `<` size comparison: an equal-width signed destination cannot hold the top
// half of the unsigned range, so it still requires the max check.
macro_rules! impl_to_element_uint_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let max = $DstT::MAX as $SrcT;
            if size_of::<$SrcT>() < size_of::<$DstT>() || *self <= max {
                *self as $DstT
            } else {
                panic!(
                    "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                    core::any::type_name::<$SrcT>(),
                    self,
                    core::any::type_name::<$DstT>(),
                )
            }
        }
    )*}
}
// Generates unsigned-int -> unsigned-int conversion methods: direct cast when
// the destination is at least as wide, otherwise range-checked with a panic
// on overflow.
macro_rules! impl_to_element_uint_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let max = $DstT::MAX as $SrcT;
            if size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max {
                *self as $DstT
            } else {
                panic!(
                    "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                    core::any::type_name::<$SrcT>(),
                    self,
                    core::any::type_name::<$DstT>(),
                )
            }
        }
    )*}
}
// Builds the full `ToElement` impl for an unsigned integer type, mirroring
// `impl_to_element_int` but with the uint-source conversion macros.
macro_rules! impl_to_element_uint {
    ($T:ident) => {
        impl ToElement for $T {
            impl_to_element_uint_to_int! { $T:
                fn to_isize -> isize;
                fn to_i8 -> i8;
                fn to_i16 -> i16;
                fn to_i32 -> i32;
                fn to_i64 -> i64;
                fn to_i128 -> i128;
            }
            impl_to_element_uint_to_uint! { $T:
                fn to_usize -> usize;
                fn to_u8 -> u8;
                fn to_u16 -> u16;
                fn to_u32 -> u32;
                fn to_u64 -> u64;
                fn to_u128 -> u128;
            }
            // uint -> float may lose precision but never fails.
            #[inline]
            fn to_f32(&self) -> f32 {
                *self as f32
            }
            #[inline]
            fn to_f64(&self) -> f64 {
                *self as f64
            }
            // C semantics: any non-zero value is `true`.
            #[inline]
            fn to_bool(&self) -> bool {
                *self != 0
            }
        }
    };
}
impl_to_element_uint!(usize);
impl_to_element_uint!(u8);
impl_to_element_uint!(u16);
impl_to_element_uint!(u32);
impl_to_element_uint!(u64);
impl_to_element_uint!(u128);
// Generates float -> float conversion methods; `as` casts between float types
// are infallible (narrowing may saturate to +-inf), so no check is needed.
macro_rules! impl_to_element_float_to_float {
    ($SrcT:ident : $( fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        fn $method(&self) -> $DstT {
            // We can safely cast all values, whether NaN, +-inf, or finite.
            // Finite values that are reducing size may saturate to +-inf.
            *self as $DstT
        }
    )*}
}
// Wraps the single `unsafe` float-truncation primitive used by the two
// range-checked float -> int macros below.
macro_rules! float_to_int_unchecked {
    // SAFETY: Must not be NaN or infinite; must be representable as the integer after truncating.
    // We already checked that the float is in the exclusive range `(MIN-1, MAX+1)`.
    ($float:expr => $int:ty) => {
        unsafe { $float.to_int_unchecked::<$int>() }
    };
}
// Generates float -> signed-int conversion methods. Accepts any float in the
// exclusive range `(MIN-1, MAX+1)` (truncating toward zero); NaN, infinities,
// and out-of-range values panic.
macro_rules! impl_to_element_float_to_signed_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $i {
            // Float as int truncates toward zero, so we want to allow values
            // in the exclusive range `(MIN-1, MAX+1)`.
            if size_of::<$f>() > size_of::<$i>() {
                // With a larger size, we can represent the range exactly.
                const MIN_M1: $f = $i::MIN as $f - 1.0;
                const MAX_P1: $f = $i::MAX as $f + 1.0;
                if *self > MIN_M1 && *self < MAX_P1 {
                    return float_to_int_unchecked!(*self => $i);
                }
            } else {
                // We can't represent `MIN-1` exactly, but there's no fractional part
                // at this magnitude, so we can just use a `MIN` inclusive boundary.
                const MIN: $f = $i::MIN as $f;
                // We can't represent `MAX` exactly, but it will round up to exactly
                // `MAX+1` (a power of two) when we cast it.
                const MAX_P1: $f = $i::MAX as $f;
                if *self >= MIN && *self < MAX_P1 {
                    return float_to_int_unchecked!(*self => $i);
                }
            }
            panic!("Float cannot be represented in the target signed int type")
        }
    )*}
}
// Generates float -> unsigned-int conversion methods. Accepts any float in
// the exclusive range `(-1, MAX+1)` (truncating toward zero); NaN, infinities,
// and out-of-range values panic.
macro_rules! impl_to_element_float_to_unsigned_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $u {
            // Float as int truncates toward zero, so we want to allow values
            // in the exclusive range `(-1, MAX+1)`.
            if size_of::<$f>() > size_of::<$u>() {
                // With a larger size, we can represent the range exactly.
                const MAX_P1: $f = $u::MAX as $f + 1.0;
                if *self > -1.0 && *self < MAX_P1 {
                    return float_to_int_unchecked!(*self => $u);
                }
            } else {
                // We can't represent `MAX` exactly, but it will round up to exactly
                // `MAX+1` (a power of two) when we cast it.
                // (`u128::MAX as f32` is infinity, but this is still ok.)
                const MAX_P1: $f = $u::MAX as $f;
                if *self > -1.0 && *self < MAX_P1 {
                    return float_to_int_unchecked!(*self => $u);
                }
            }
            panic!("Float cannot be represented in the target unsigned int type")
        }
    )*}
}
// Builds the full `ToElement` impl for a native float type from the
// range-checked float -> int macros and the infallible float -> float casts.
macro_rules! impl_to_element_float {
    ($T:ident) => {
        impl ToElement for $T {
            impl_to_element_float_to_signed_int! { $T:
                fn to_isize -> isize;
                fn to_i8 -> i8;
                fn to_i16 -> i16;
                fn to_i32 -> i32;
                fn to_i64 -> i64;
                fn to_i128 -> i128;
            }
            impl_to_element_float_to_unsigned_int! { $T:
                fn to_usize -> usize;
                fn to_u8 -> u8;
                fn to_u16 -> u16;
                fn to_u32 -> u32;
                fn to_u64 -> u64;
                fn to_u128 -> u128;
            }
            impl_to_element_float_to_float! { $T:
                fn to_f32 -> f32;
                fn to_f64 -> f64;
            }
            // C semantics: any non-zero value (including NaN) is `true`.
            #[inline]
            fn to_bool(&self) -> bool {
                *self != 0.0
            }
        }
    };
}
impl_to_element_float!(f32);
impl_to_element_float!(f64);
// `ToElement` for `half::f16`: every conversion first widens to `f32` via
// `Self::to_f32(..)` — the path form resolves to the inherent `f16::to_f32`
// from the `half` crate (not this trait's method, which would recurse) — and
// then reuses the range-checked `f32` conversions above.
impl ToElement for f16 {
    #[inline]
    fn to_i64(&self) -> i64 {
        Self::to_f32(*self).to_i64()
    }
    #[inline]
    fn to_u64(&self) -> u64 {
        Self::to_f32(*self).to_u64()
    }
    #[inline]
    fn to_i8(&self) -> i8 {
        Self::to_f32(*self).to_i8()
    }
    #[inline]
    fn to_u8(&self) -> u8 {
        Self::to_f32(*self).to_u8()
    }
    #[inline]
    fn to_i16(&self) -> i16 {
        Self::to_f32(*self).to_i16()
    }
    #[inline]
    fn to_u16(&self) -> u16 {
        Self::to_f32(*self).to_u16()
    }
    #[inline]
    fn to_i32(&self) -> i32 {
        Self::to_f32(*self).to_i32()
    }
    #[inline]
    fn to_u32(&self) -> u32 {
        Self::to_f32(*self).to_u32()
    }
    // Identity: avoids the default f32 round-trip.
    #[inline]
    fn to_f16(&self) -> f16 {
        *self
    }
    #[inline]
    fn to_f32(&self) -> f32 {
        Self::to_f32(*self)
    }
    #[inline]
    fn to_f64(&self) -> f64 {
        Self::to_f64(*self)
    }
    // C semantics: any non-zero value is `true`. Note that both +0.0 and -0.0
    // compare equal to zero, so both map to `false`.
    #[inline]
    fn to_bool(&self) -> bool {
        *self != f16::from_f32_const(0.0)
    }
}
// `ToElement` for `half::bf16`: mirrors the `f16` impl above — widen to `f32`
// through the inherent `bf16::to_f32` and reuse the checked `f32` conversions.
impl ToElement for bf16 {
    #[inline]
    fn to_i64(&self) -> i64 {
        Self::to_f32(*self).to_i64()
    }
    #[inline]
    fn to_u64(&self) -> u64 {
        Self::to_f32(*self).to_u64()
    }
    #[inline]
    fn to_i8(&self) -> i8 {
        Self::to_f32(*self).to_i8()
    }
    #[inline]
    fn to_u8(&self) -> u8 {
        Self::to_f32(*self).to_u8()
    }
    #[inline]
    fn to_i16(&self) -> i16 {
        Self::to_f32(*self).to_i16()
    }
    #[inline]
    fn to_u16(&self) -> u16 {
        Self::to_f32(*self).to_u16()
    }
    #[inline]
    fn to_i32(&self) -> i32 {
        Self::to_f32(*self).to_i32()
    }
    #[inline]
    fn to_u32(&self) -> u32 {
        Self::to_f32(*self).to_u32()
    }
    // Identity: avoids the default f32 round-trip.
    #[inline]
    fn to_bf16(&self) -> bf16 {
        *self
    }
    #[inline]
    fn to_f32(&self) -> f32 {
        Self::to_f32(*self)
    }
    #[inline]
    fn to_f64(&self) -> f64 {
        Self::to_f64(*self)
    }
    // C semantics: any non-zero value is `true`.
    #[inline]
    fn to_bool(&self) -> bool {
        *self != bf16::from_f32_const(0.0)
    }
}
| #[cfg(feature = "cubecl")] | ||
| impl ToElement for cubecl::flex32 { | ||
| #[inline] | ||
| fn to_i64(&self) -> i64 { | ||
| Self::to_f32(*self).to_i64() | ||
| } | ||
| #[inline] | ||
| fn to_u64(&self) -> u64 { | ||
| Self::to_f32(*self).to_u64() | ||
| } | ||
| #[inline] | ||
| fn to_i8(&self) -> i8 { | ||
| Self::to_f32(*self).to_i8() | ||
| } | ||
| #[inline] | ||
| fn to_u8(&self) -> u8 { | ||
| Self::to_f32(*self).to_u8() | ||
| } | ||
| #[inline] | ||
| fn to_i16(&self) -> i16 { | ||
| Self::to_f32(*self).to_i16() | ||
| } | ||
| #[inline] | ||
| fn to_u16(&self) -> u16 { | ||
| Self::to_f32(*self).to_u16() | ||
| } | ||
| #[inline] | ||
| fn to_i32(&self) -> i32 { | ||
| Self::to_f32(*self).to_i32() | ||
| } | ||
| #[inline] | ||
| fn to_u32(&self) -> u32 { | ||
| Self::to_f32(*self).to_u32() | ||
| } | ||
| #[inline] | ||
| fn to_f32(&self) -> f32 { | ||
| Self::to_f32(*self) | ||
| } | ||
| #[inline] | ||
| fn to_f64(&self) -> f64 { | ||
| Self::to_f64(*self) | ||
| } | ||
| #[inline] | ||
| fn to_bool(&self) -> bool { | ||
| *self != cubecl::flex32::from_f32(0.0) | ||
| } | ||
| } | ||
| impl ToElement for bool { | ||
| #[inline] | ||
| fn to_i64(&self) -> i64 { | ||
| *self as i64 | ||
| } | ||
| #[inline] | ||
| fn to_u64(&self) -> u64 { | ||
| *self as u64 | ||
| } | ||
| #[inline] | ||
| fn to_i8(&self) -> i8 { | ||
| *self as i8 | ||
| } | ||
| #[inline] | ||
| fn to_u8(&self) -> u8 { | ||
| *self as u8 | ||
| } | ||
| #[inline] | ||
| fn to_i16(&self) -> i16 { | ||
| *self as i16 | ||
| } | ||
| #[inline] | ||
| fn to_u16(&self) -> u16 { | ||
| *self as u16 | ||
| } | ||
| #[inline] | ||
| fn to_i32(&self) -> i32 { | ||
| *self as i32 | ||
| } | ||
| #[inline] | ||
| fn to_u32(&self) -> u32 { | ||
| *self as u32 | ||
| } | ||
| #[inline] | ||
| fn to_f32(&self) -> f32 { | ||
| self.to_u8() as f32 | ||
| } | ||
| #[inline] | ||
| fn to_f64(&self) -> f64 { | ||
| self.to_u8() as f64 | ||
| } | ||
| #[inline] | ||
| fn to_bool(&self) -> bool { | ||
| *self | ||
| } | ||
| } | ||
// Gated behind `cfg(test)` so the module (and its glob import of the parent)
// is only compiled for test builds; this also makes the previous
// `#[allow(unused_imports)]` workaround unnecessary.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn to_element_float() {
        let f32_toolarge = 1e39f64;
        assert_eq!(f32_toolarge.to_f32(), f32::INFINITY);
        assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY);
        assert_eq!((f32::MAX as f64).to_f32(), f32::MAX);
        assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX);
        assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);
        assert!((f64::NAN).to_f32().is_nan());
    }
    // Negative signed values must panic when converted to any unsigned type.
    #[test]
    #[should_panic]
    fn to_element_signed_to_u8_underflow() {
        let _x = (-1i8).to_u8();
    }
    #[test]
    #[should_panic]
    fn to_element_signed_to_u16_underflow() {
        let _x = (-1i8).to_u16();
    }
    #[test]
    #[should_panic]
    fn to_element_signed_to_u32_underflow() {
        let _x = (-1i8).to_u32();
    }
    #[test]
    #[should_panic]
    fn to_element_signed_to_u64_underflow() {
        let _x = (-1i8).to_u64();
    }
    #[test]
    #[should_panic]
    fn to_element_signed_to_u128_underflow() {
        let _x = (-1i8).to_u128();
    }
    #[test]
    #[should_panic]
    fn to_element_signed_to_usize_underflow() {
        let _x = (-1i8).to_usize();
    }
    // Values above the destination's maximum must panic when narrowing.
    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u8_overflow() {
        let _x = 256.to_u8();
    }
    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u16_overflow() {
        let _x = 65_536.to_u16();
    }
    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u32_overflow() {
        let _x = 4_294_967_296u64.to_u32();
    }
    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u64_overflow() {
        let _x = 18_446_744_073_709_551_616u128.to_u64();
    }
    #[test]
    fn to_element_int_to_float() {
        assert_eq!((-1).to_f32(), -1.0);
        assert_eq!((-1).to_f64(), -1.0);
        assert_eq!(255.to_f32(), 255.0);
        assert_eq!(65_535.to_f64(), 65_535.0);
    }
    // Float -> int truncates toward zero.
    #[test]
    fn to_element_float_to_int() {
        assert_eq!((-1.0).to_i8(), -1);
        assert_eq!(1.0.to_u8(), 1);
        assert_eq!(1.8.to_u16(), 1);
        assert_eq!(123.456.to_u32(), 123);
    }
}
// Element type definitions (re-exported below).
mod base;
/// Tensor element casting.
pub mod cast;
// Flatten the `base` items into this module's public API.
pub use base::*;
| //! A module for indexing utility machinery. | ||
| use core::fmt::Debug; | ||
/// Helper trait for implementing indexing with support for negative indices.
///
/// # Example
/// ```rust
/// use burn_tensor::indexing::{AsIndex, canonicalize_dim};
///
/// fn example<I: AsIndex, const D: usize>(dim: I, size: usize) -> isize {
///     let dim: usize = canonicalize_dim(dim, D, false);
///     unimplemented!()
/// }
/// ```
pub trait AsIndex: Debug + Copy + Sized {
    /// Converts into a slice index.
    fn index(self) -> isize;
}

// Implements `AsIndex` for the primitive integer types via a plain
// `as isize` cast (sign-preserving for signed sources, zero-extending or
// truncating for unsigned ones wider than `isize`).
macro_rules! impl_as_index {
    ($($ty:ty),* $(,)?) => {
        $(
            impl AsIndex for $ty {
                fn index(self) -> isize {
                    self as isize
                }
            }
        )*
    };
}

impl_as_index!(usize, isize, i64, u64, i32, u32, i16, u16, i8, u8);
/// Canonicalizes and bounds checks an index with negative indexing support.
///
/// ## Arguments
///
/// * `idx` - The index to canonicalize.
/// * `size` - The size of the index range.
/// * `wrap_scalar` - If true, pretend scalars have rank=1.
///
/// ## Returns
///
/// The canonicalized dimension index.
///
/// ## Panics
///
/// * If `wrap_scalar` is false and the tensor has no dimensions.
/// * If the dimension index is out of range.
#[must_use]
pub fn canonicalize_index<Index>(idx: Index, size: usize, wrap_scalar: bool) -> usize
where
    Index: AsIndex,
{
    // Shares the implementation with `canonicalize_dim`; only the names used
    // in panic messages differ.
    canonicalize_named_index("index", "size", idx, size, wrap_scalar)
}
/// Canonicalizes and bounds checks a dimension index with negative indexing support.
///
/// ## Arguments
///
/// * `idx` - The dimension index to canonicalize.
/// * `rank` - The number of dimensions.
/// * `wrap_scalar` - If true, pretend scalars have rank=1.
///
/// ## Returns
///
/// The canonicalized dimension index.
///
/// ## Panics
///
/// * If `wrap_scalar` is false and the tensor has no dimensions.
/// * If the dimension index is out of range.
#[must_use]
pub fn canonicalize_dim<Dim>(idx: Dim, rank: usize, wrap_scalar: bool) -> usize
where
    Dim: AsIndex,
{
    // Shares the implementation with `canonicalize_index`; only the names
    // used in panic messages differ.
    canonicalize_named_index("dimension index", "rank", idx, rank, wrap_scalar)
}
| /// Canonicalizes and bounds checks an index with negative indexing support. | ||
| /// | ||
| /// ## Arguments | ||
| /// | ||
| /// * `name` - The name of the index (for error messages). | ||
| /// * `size_name` - The name of the size (for error messages). | ||
| /// * `idx` - The index to canonicalize. | ||
| /// * `size` - The size of the index range. | ||
| /// * `wrap_scalar` - If true, treat 0-size ranges as having size 1. | ||
| /// | ||
| /// ## Returns | ||
| /// | ||
| /// The canonicalized index. | ||
| /// | ||
| /// ## Panics | ||
| /// | ||
| /// * If `wrap_scalar` is false and the size is 0. | ||
| /// * If the index is out of range for the dimension size. | ||
| #[inline(always)] | ||
| #[must_use] | ||
| fn canonicalize_named_index<I>( | ||
| name: &str, | ||
| size_name: &str, | ||
| idx: I, | ||
| size: usize, | ||
| wrap_scalar: bool, | ||
| ) -> usize | ||
| where | ||
| I: AsIndex, | ||
| { | ||
| let idx = idx.index(); | ||
| let rank = if size > 0 { | ||
| size | ||
| } else { | ||
| if !wrap_scalar { | ||
| panic!("{name} {idx} used when {size_name} is 0"); | ||
| } | ||
| 1 | ||
| }; | ||
| if idx >= 0 && (idx as usize) < rank { | ||
| return idx as usize; | ||
| } | ||
| let _idx = if idx < 0 { idx + rank as isize } else { idx }; | ||
| if _idx < 0 || (_idx as usize) >= rank { | ||
| let rank = rank as isize; | ||
| let lower = -rank; | ||
| let upper = rank - 1; | ||
| panic!("{name} {idx} out of range: ({lower}..={upper})"); | ||
| } | ||
| _idx as usize | ||
| } | ||
| /// Wraps a dimension index to be within the bounds of the dimension size. | ||
| /// | ||
| /// ## Arguments | ||
| /// | ||
| /// * `idx` - The dimension index to wrap. | ||
| /// * `size` - The size of the dimension. | ||
| /// | ||
| /// ## Returns | ||
| /// | ||
| /// The positive wrapped dimension index. | ||
| #[inline] | ||
| #[must_use] | ||
| pub fn wrap_index<I>(idx: I, size: usize) -> usize | ||
| where | ||
| I: AsIndex, | ||
| { | ||
| if size == 0 { | ||
| return 0; // Avoid modulo by zero | ||
| } | ||
| let wrapped = idx.index().rem_euclid(size as isize); | ||
| if wrapped < 0 { | ||
| (wrapped + size as isize) as usize | ||
| } else { | ||
| wrapped as usize | ||
| } | ||
| } | ||
| /// Compute the ravel index for the given coordinates. | ||
| /// | ||
| /// This returns the row-major order raveling: | ||
| /// * `strides[-1] = 1` | ||
| /// * `strides[i] = strides[i+1] * dims[i+1]` | ||
| /// * `dim_strides = coords * strides` | ||
| /// * `ravel = sum(dim_strides)` | ||
| /// | ||
| /// # Arguments | ||
| /// - `indices`: the index for each dimension; must be the same length as `shape`. | ||
| /// - `shape`: the shape of each dimension; be the same length as `indices`. | ||
| /// | ||
| /// # Returns | ||
| /// - the ravel offset index. | ||
| pub fn ravel_index<I: AsIndex>(indices: &[I], shape: &[usize]) -> usize { | ||
| assert_eq!( | ||
| shape.len(), | ||
| indices.len(), | ||
| "Coordinate rank mismatch: expected {}, got {}", | ||
| shape.len(), | ||
| indices.len(), | ||
| ); | ||
| let mut ravel_idx = 0; | ||
| let mut stride = 1; | ||
| for (i, &dim) in shape.iter().enumerate().rev() { | ||
| let coord = canonicalize_index(indices[i], dim, false); | ||
| ravel_idx += coord * stride; | ||
| stride *= dim; | ||
| } | ||
| ravel_idx | ||
| } | ||
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use alloc::vec;
    #[test]
    fn test_ravel() {
        let shape = vec![2, 3, 4, 5];
        assert_eq!(ravel_index(&[0, 0, 0, 0], &shape), 0);
        // Spelled out as the explicit row-major stride sum for clarity.
        assert_eq!(
            ravel_index(&[1, 2, 3, 4], &shape),
            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
        );
    }
    // Every congruent index (idx + k*size for any integer k) wraps to the
    // same canonical position.
    #[test]
    fn test_wrap_idx() {
        assert_eq!(wrap_index(0, 3), 0_usize);
        assert_eq!(wrap_index(3, 3), 0_usize);
        assert_eq!(wrap_index(2 * 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 2 * 3, 3), 0_usize);
        assert_eq!(wrap_index(1, 3), 1_usize);
        assert_eq!(wrap_index(1 + 3, 3), 1_usize);
        assert_eq!(wrap_index(1 + 2 * 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 2 * 3, 3), 1_usize);
        assert_eq!(wrap_index(2, 3), 2_usize);
        assert_eq!(wrap_index(2 + 3, 3), 2_usize);
        assert_eq!(wrap_index(2 + 2 * 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 2 * 3, 3), 2_usize);
    }
    #[test]
    fn test_canonicalize_dim() {
        let wrap_scalar = false;
        assert_eq!(canonicalize_dim(0, 3, wrap_scalar), 0_usize);
        assert_eq!(canonicalize_dim(1, 3, wrap_scalar), 1_usize);
        assert_eq!(canonicalize_dim(2, 3, wrap_scalar), 2_usize);
        assert_eq!(canonicalize_dim(-1, 3, wrap_scalar), (3 - 1) as usize);
        assert_eq!(canonicalize_dim(-2, 3, wrap_scalar), (3 - 2) as usize);
        assert_eq!(canonicalize_dim(-3, 3, wrap_scalar), (3 - 3) as usize);
        // With `wrap_scalar`, rank 0 behaves as rank 1.
        let wrap_scalar = true;
        assert_eq!(canonicalize_dim(0, 0, wrap_scalar), 0);
        assert_eq!(canonicalize_dim(-1, 0, wrap_scalar), 0);
    }
    #[test]
    #[should_panic = "dimension index 0 used when rank is 0"]
    fn test_canonicalize_dim_error_no_dims() {
        let _d = canonicalize_dim(0, 0, false);
    }
    #[test]
    #[should_panic = "dimension index 3 out of range: (-3..=2)"]
    fn test_canonicalize_dim_error_too_big() {
        let _d = canonicalize_dim(3, 3, false);
    }
    #[test]
    #[should_panic = "dimension index -4 out of range: (-3..=2)"]
    fn test_canonicalize_dim_error_too_small() {
        let _d = canonicalize_dim(-4, 3, false);
    }
    #[test]
    fn test_canonicalize_index() {
        let wrap_scalar = false;
        assert_eq!(canonicalize_index(0, 3, wrap_scalar), 0_usize);
        assert_eq!(canonicalize_index(1, 3, wrap_scalar), 1_usize);
        assert_eq!(canonicalize_index(2, 3, wrap_scalar), 2_usize);
        assert_eq!(canonicalize_index(-1, 3, wrap_scalar), (3 - 1) as usize);
        assert_eq!(canonicalize_index(-2, 3, wrap_scalar), (3 - 2) as usize);
        assert_eq!(canonicalize_index(-3, 3, wrap_scalar), (3 - 3) as usize);
        let wrap_scalar = true;
        assert_eq!(canonicalize_index(0, 0, wrap_scalar), 0);
        assert_eq!(canonicalize_index(-1, 0, wrap_scalar), 0);
    }
    #[test]
    #[should_panic = "index 3 out of range: (-3..=2)"]
    fn test_canonicalize_index_error_too_big() {
        let _d = canonicalize_index(3, 3, false);
    }
}
| use core::any::TypeId; | ||
| use crate::{Bytes, Element, quantization::unpack_q_to_i8s}; | ||
| use alloc::vec::Vec; | ||
| use super::{QParams, QuantLevel, QuantScheme, QuantStore, QuantValue}; | ||
/// Quantized data bytes representation.
///
/// # Notes
/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
///    quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
///    we make sure to retrieve only the meaningful values (and ignore the alignment padding).
/// 2) Quantization parameters are appended to the tensor data.
///    As such, the last bytes always correspond to the scale parameter.
///    If the quantization scheme includes an offset (zero-point) parameter, it is next to last.
pub struct QuantizedBytes {
    /// The quantized values and quantization parameters represented as bytes.
    pub bytes: Bytes,
    /// The quantization scheme.
    pub scheme: QuantScheme,
    /// The number of quantized elements.
    // Tracked separately because packing/padding means the element count
    // cannot be recovered from the byte length alone (see note 1 above).
    pub num_elements: usize,
}
impl QuantizedBytes {
    /// Creates a new quantized bytes representation.
    ///
    /// # Arguments
    /// - `value`: the quantized values; must actually be `i8` (checked at runtime).
    /// - `scheme`: the quantization scheme (level / value type / storage).
    /// - `scales`: one scale for per-tensor quantization, one per block for block-level.
    ///
    /// # Panics
    /// Panics if `E` is not `i8`.
    pub fn new<E: Element>(value: Vec<E>, scheme: QuantScheme, scales: &[f32]) -> Self {
        let num_elements = value.len();
        // Only used for 8-bit quantization data comparison in tests
        if TypeId::of::<E>() != TypeId::of::<i8>() {
            panic!("Invalid quantized type");
        }
        // Re-interpret `Vec<E>` as `Vec<i8>` (same size/alignment; checked by bytemuck).
        let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
        let mut bytes = Bytes::from_elems(i8s);
        // Append the f32 scale parameter(s) after the values so the trailing
        // bytes of the buffer always hold the quantization parameters.
        match scheme.level {
            QuantLevel::Tensor => {
                let scale_bytes = bytemuck::bytes_of(&scales[0]);
                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
            }
            QuantLevel::Block(_block_size) => {
                let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
                for scale in scales {
                    scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
                }
                bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), align_of::<f32>());
            }
        }
        Self {
            bytes,
            scheme,
            num_elements,
        }
    }
    /// Returns the int8 quantized values with the quantization parameters.
    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
        let (values, (qparams, num_params)) = self.split_values_off();
        // Quantization parameters are added at the end of the tensor data.
        // As such, the last bytes always correspond to the scale parameter(s).
        // For example, per-block quantization can have multiple parameters for a single tensor:
        // [scale, scale, scale, ...]
        let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32
        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
        let total_bytes = qparams_bytes.len();
        let scales_size = scale_size * num_params;
        // Re-read the trailing `scales_size` bytes as f32 scales.
        let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
        (values, QParams { scales })
    }
    /// Splits the trailing quantization parameters from the quantized values.
    ///
    /// Returns `(values, qparams)` where `qparams` holds the raw `u32` words of
    /// the `num_params` trailing `f32` scales.
    fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
        let mut values = read_bytes_to_i8(self.bytes);
        // Scales are stored as f32 at the end of the byte buffer.
        let scale_size = num_params * size_of::<f32>();
        let values_end = values.len() - scale_size;
        let qparams = values.split_off(values_end);
        // Fast path: reinterpret the split-off tail in place as native-endian u32 words.
        // NOTE(review): this reuses a buffer allocated as `Vec<i8>` (align 1) for a
        // `Vec<u32>` (align 4) via `Vec::from_raw_parts`; it relies on the allocator
        // accepting the layout change and on len/capacity being multiples of 4 —
        // TODO confirm these invariants are guaranteed by `Bytes`.
        let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
            let mut qparams = core::mem::ManuallyDrop::new(qparams);
            unsafe {
                Vec::<u32>::from_raw_parts(
                    qparams.as_mut_ptr() as _,
                    qparams.len() / 4,
                    qparams.capacity() / 4,
                )
            }
        } else {
            // Slow path for unaligned tails: repack byte-by-byte.
            #[cfg(target_endian = "little")]
            {
                // SAFETY: quantized bytes representation is created from packed u32 values in little endian
                bytemuck::cast_vec(qparams)
            }
            #[cfg(target_endian = "big")]
            {
                crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
            }
        };
        (values, qparams)
    }
    /// Splits the quantized values of the tensor from the quantization parameters.
    ///
    /// Returns the values in i8 and a newly allocated vector containing the quantization
    /// parameters, together with the number of parameters.
    fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
        // Per-tensor quantization has a single scale; block-level has one per block.
        let num_params = match self.scheme.level {
            QuantLevel::Tensor => 1,
            QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
        };
        let (values, qparams) = match self.scheme.store {
            QuantStore::Native => self.split_i8_values(num_params),
            QuantStore::U32 => match self.scheme.value {
                QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                    let mut values = self.bytes.try_into_vec::<u32>().unwrap();
                    let scale_size = num_params; // size of f32 same as u32
                    let values_end = values.len() - scale_size;
                    let qparams = values.split_off(values_end);
                    // Sub-byte values are unpacked as i8s for value equality tests
                    let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
                    (values, qparams)
                }
                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
                    unimplemented!("Not yet supported")
                }
            },
        };
        (values, (qparams, num_params))
    }
}
| fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> { | ||
| match bytes.try_into_vec::<i8>() { | ||
| Ok(val) => val, | ||
| // Safety, | ||
| // | ||
| // `Vec<u8>` can be Re-interpreted as `Vec<i8>` since they share the same alignment. | ||
| Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) }, | ||
| } | ||
| } | ||
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        // Values quantized from [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] with a
        // symmetric per-tensor 8-bit scheme.
        let scale = 0.03937008;
        let expected = vec![0i8, 25, 51, 76, 102, 127];
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);
        let packed = QuantizedBytes::new(expected.clone(), scheme, &[scale]);

        // Round-trip: unpacking must yield the exact values and the scale.
        let (unpacked, qparams) = packed.into_vec_i8();
        assert_eq!(qparams.scales, vec![scale]);
        assert_eq!(unpacked, expected);
    }
}
| use crate::quantization::QuantValue; | ||
| use alloc::vec::Vec; | ||
| use num_traits::PrimInt; | ||
/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
///
/// Each group of four bytes `[a, b, c, d]` is combined into a single `u32` with
/// the first byte in the least significant position:
/// `(d as u32 & 0xFF) << 24 | (c as u32 & 0xFF) << 16 | (b as u32 & 0xFF) << 8 | (a as u32 & 0xFF)`.
/// A trailing partial chunk is zero-padded.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    // NOTE: the previous little-endian fast path reinterpreted the `Vec<i8>`
    // allocation as a `Vec<u32>` via `Vec::from_raw_parts`. That is unsound:
    // the buffer was allocated with the layout of `i8` (alignment 1) but would
    // be deallocated with the layout of `u32` (alignment 4, capacity / 4), and
    // the data pointer is not guaranteed to be 4-aligned. Packing explicitly
    // is endian-independent, has no unsafe code, and produces the same result:
    // a short final chunk simply folds in implicit zero padding.
    values
        .chunks(4)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .fold(0u32, |acc, (i, &byte)| {
                    acc | ((byte as u32 & 0xFF) << (i * 8))
                })
        })
        .collect()
}
| /// Unpack integer values into a sequence of signed 8-bit integers. | ||
| pub(crate) fn unpack_q_to_i8s<Q: PrimInt>( | ||
| values: &[Q], | ||
| numel: usize, | ||
| value: &QuantValue, | ||
| ) -> Vec<i8> { | ||
| let size_store = size_of::<Q>() * 8; | ||
| let size_quant = value.size_bits(); | ||
| let num_quants = size_store / size_quant; | ||
| let mask = Q::from((1 << size_quant) - 1).unwrap(); | ||
| let sign_shift = 8 - size_quant; // sign extension for sub-byte values | ||
| values | ||
| .iter() | ||
| .enumerate() | ||
| .flat_map(|(i, &packed)| { | ||
| // A single u32 could contain less than four 8-bit values... | ||
| let n = core::cmp::min(num_quants, numel - i * num_quants); | ||
| // Extract each 8-bit segment from u32 and cast back to i8 | ||
| // Same as doing this (when 4 values are fully packed): | ||
| // let a = (packed & 0xFF) as i8; | ||
| // let b = ((packed >> 8) & 0xFF) as i8; | ||
| // let c = ((packed >> 16) & 0xFF) as i8; | ||
| // let d = ((packed >> 24) & 0xFF) as i8; | ||
| (0..n).map(move |i| { | ||
| let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap(); | ||
| ((raw << sign_shift) as i8) >> sign_shift | ||
| }) | ||
| }) | ||
| .collect() | ||
| } | ||
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    // Four i8s pack into one u32, first value in the low byte.
    #[test]
    fn should_pack_i8s_to_u32() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);
        assert_eq!(packed, vec![2147287680]);
    }
    // A partial trailing chunk is zero-padded: packing 5 values equals
    // packing the same values explicitly padded to 8 with zeros.
    #[test]
    fn should_pack_i8s_to_u32_padded() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);
        assert_eq!(packed, vec![2147287680, 55]);
        assert_eq!(packed, packed_padded);
    }
    // Unpacking is the inverse of packing for fully-occupied words.
    #[test]
    fn should_unpack_u32s_to_i8s() {
        let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);
        assert_eq!(unpacked, vec![-128, 2, -3, 127]);
    }
    // With numel = 1, the padding bytes of the word are ignored.
    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);
        assert_eq!(unpacked, vec![55]);
    }
    // 4-bit values: 8 per u32 word, sign-extended on unpack.
    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        let unpacked = unpack_q_to_i8s(
            &[
                0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
                1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
                2004318071,
            ],
            128,
            &QuantValue::Q4S,
        );
        assert_eq!(
            unpacked,
            vec![
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
            ]
        );
    }
}
-1139
| use crate::indexing::ravel_index; | ||
| use crate::{AsIndex, Slice, SliceArg}; | ||
| use alloc::vec::Vec; | ||
| use core::{ | ||
| ops::{Deref, DerefMut, Index, IndexMut, Range}, | ||
| slice::{Iter, IterMut, SliceIndex}, | ||
| }; | ||
| use serde::{Deserialize, Serialize}; | ||
/// Shape of a tensor.
///
/// A thin wrapper around the dimension sizes, ordered from the outermost
/// dimension to the innermost.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Shape {
    /// The dimensions of the tensor.
    pub dims: Vec<usize>,
}
#[allow(missing_docs)]
#[derive(Debug, Clone, PartialEq, Eq)]
/// Error that can occur when attempting to modify shapes.
pub enum ShapeError {
    /// The operands have different ranks.
    RankMismatch { left: usize, right: usize },
    /// A pair of dimensions are incompatible for broadcasting
    /// (neither equal nor 1), at dimension index `dim`.
    IncompatibleDims {
        left: usize,
        right: usize,
        dim: usize,
    },
    /// Invalid dimension specified for the rank.
    OutOfBounds { dim: usize, rank: usize },
    /// A pair of shapes are incompatible for the operation.
    IncompatibleShapes { left: Shape, right: Shape },
    /// Invalid empty shape (e.g. an empty iterator of shapes).
    Empty,
}
| impl Shape { | ||
| /// Constructs a new `Shape`. | ||
| pub fn new<const D: usize>(dims: [usize; D]) -> Self { | ||
| // For backward compat | ||
| Self { | ||
| dims: dims.to_vec(), | ||
| } | ||
| } | ||
| /// Returns the total number of elements of a tensor having this shape | ||
| pub fn num_elements(&self) -> usize { | ||
| self.dims.iter().product() | ||
| } | ||
| /// Returns the number of dimensions. | ||
| /// | ||
| /// Alias for `Shape::rank()`. | ||
| pub fn num_dims(&self) -> usize { | ||
| self.dims.len() | ||
| } | ||
| /// Returns the rank (the number of dimensions). | ||
| /// | ||
| /// Alias for `Shape::num_dims()`. | ||
| pub fn rank(&self) -> usize { | ||
| self.num_dims() | ||
| } | ||
| // For compat with dims: [usize; D] | ||
| /// Returns the dimensions of the tensor as an array. | ||
| pub fn dims<const D: usize>(&self) -> [usize; D] { | ||
| let mut dims = [1; D]; | ||
| dims[..D].copy_from_slice(&self.dims[..D]); | ||
| dims | ||
| } | ||
| /// Change the shape to one dimensional with the same number of elements. | ||
| pub fn flatten(mut self) -> Self { | ||
| self.dims = [self.num_elements()].into(); | ||
| self | ||
| } | ||
| /// Compute the ravel index for the given coordinates. | ||
| /// | ||
| /// This returns the row-major order raveling: | ||
| /// * `strides[-1] = 1` | ||
| /// * `strides[i] = strides[i+1] * dims[i+1]` | ||
| /// * `dim_strides = coords * strides` | ||
| /// * `ravel = sum(dim_strides)` | ||
| /// | ||
| /// # Arguments | ||
| /// - `indices`: the index for each dimension; must be the same length as `shape`. | ||
| /// | ||
| /// # Returns | ||
| /// - the ravel offset index. | ||
| pub fn ravel_index<I: AsIndex>(&self, indices: &[I]) -> usize { | ||
| ravel_index(indices, &self.dims) | ||
| } | ||
| /// Convert shape dimensions to full covering ranges (0..dim) for each dimension. | ||
| pub fn into_ranges(self) -> Vec<Range<usize>> { | ||
| self.into_iter().map(|d| 0..d).collect() | ||
| } | ||
| /// Converts slice arguments into an array of slice specifications for the shape. | ||
| /// | ||
| /// This method returns an array of `Slice` objects that can be used for slicing operations. | ||
| /// The slices are clamped to the shape's dimensions. Similar to `into_ranges()`, but | ||
| /// allows custom slice specifications instead of full ranges. | ||
| /// For creating complex slice specifications, use the [`s!`] macro. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `slices` - An array of slice specifications, where each element can be: | ||
| /// - A range (e.g., `2..5`) | ||
| /// - An index | ||
| /// - A `Slice` object | ||
| /// - The output of the [`s!`] macro for advanced slicing | ||
| /// | ||
| /// # Behavior | ||
| /// | ||
| /// - Supports partial and full slicing in any number of dimensions. | ||
| /// - Missing ranges are treated as full slices if D > D2. | ||
| /// - Handles negative indices by wrapping around from the end of the dimension. | ||
| /// - Clamps ranges to the shape's dimensions if they exceed the bounds. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// An array of `Slice` objects corresponding to the provided slice specifications, | ||
| /// clamped to the shape's actual dimensions. | ||
| /// | ||
| /// # Examples | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::{Tensor, Shape, Slice, s}; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// // 1D slicing | ||
| /// let slices = Shape::new([4]).into_slices(1..4); | ||
| /// assert_eq!(slices[0].to_range(4), 1..3); | ||
| /// | ||
| /// // 2D slicing | ||
| /// let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]); | ||
| /// assert_eq!(slices[0].to_range(3), 1..3); | ||
| /// assert_eq!(slices[1].to_range(4), 0..2); | ||
| /// | ||
| /// // Using negative indices | ||
| /// let slices = Shape::new([3]).into_slices(..-2); | ||
| /// assert_eq!(slices[0].to_range(3), 0..1); | ||
| /// | ||
| /// // Using the slice macro to select different ranges | ||
| /// let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]); | ||
| /// assert_eq!(slices[0].to_range(2), 0..2); | ||
| /// assert_eq!(slices[1].to_range(3), 1..2); | ||
| /// } | ||
| /// ``` | ||
| /// | ||
| /// # See Also | ||
| /// | ||
| /// - [`s!`] - The recommended macro for creating slice specifications | ||
| /// - [`Tensor::slice`] - Apply slicing to a tensor | ||
| /// - [`Shape::into_ranges`] - Convert to full covering ranges | ||
| /// | ||
| /// [`s!`]: crate::s! | ||
| /// [`Tensor::slice`]: crate::Tensor::slice | ||
| pub fn into_slices<const D: usize, S>(self, slices: S) -> [Slice; D] | ||
| where | ||
| S: SliceArg<D>, | ||
| { | ||
| slices.into_slices(self) | ||
| } | ||
| /// Construct a vector of the dims. | ||
| pub fn to_vec(&self) -> Vec<usize> { | ||
| self.dims.clone() | ||
| } | ||
| /// Returns an iterator over the shape dimensions. | ||
| pub fn iter(&self) -> Iter<'_, usize> { | ||
| self.dims.iter() | ||
| } | ||
| /// Mutable iterator over the dimensions. | ||
| pub fn iter_mut(&mut self) -> IterMut<'_, usize> { | ||
| self.dims.iter_mut() | ||
| } | ||
| /// Borrow the underlying dimensions slice. | ||
| pub fn as_slice(&self) -> &[usize] { | ||
| &self.dims | ||
| } | ||
| /// Borrow the underlying dimensions slice mutably. | ||
| pub fn as_mut_slice(&mut self) -> &mut [usize] { | ||
| &mut self.dims | ||
| } | ||
| /// Insert a dimension of `size` at position `index`. | ||
| pub fn insert(&mut self, index: usize, size: usize) { | ||
| self.dims.insert(index, size); | ||
| } | ||
| /// Remove and return the dimension at position `index` from the shape. | ||
| pub fn remove(&mut self, index: usize) -> usize { | ||
| self.dims.remove(index) | ||
| } | ||
| /// Appends a dimension of `size` to the back of the shape. | ||
| pub fn push(&mut self, size: usize) { | ||
| self.dims.push(size) | ||
| } | ||
| /// Extend the shape with the content of another shape or iterator. | ||
| pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) { | ||
| self.dims.extend(iter) | ||
| } | ||
| /// Swap two dimensions in the shape. | ||
| pub fn swap(mut self, dim1: usize, dim2: usize) -> Result<Self, ShapeError> { | ||
| if dim1 > self.rank() { | ||
| return Err(ShapeError::OutOfBounds { | ||
| dim: dim1, | ||
| rank: self.rank(), | ||
| }); | ||
| } | ||
| if dim2 > self.rank() { | ||
| return Err(ShapeError::OutOfBounds { | ||
| dim: dim2, | ||
| rank: self.rank(), | ||
| }); | ||
| } | ||
| self.dims.swap(dim1, dim2); | ||
| Ok(self) | ||
| } | ||
| /// Reorder the shape dimensions according to the permutation of `axes`. | ||
| pub fn permute(mut self, axes: &[usize]) -> Result<Self, ShapeError> { | ||
| if axes.len() != self.rank() { | ||
| return Err(ShapeError::RankMismatch { | ||
| left: self.rank(), | ||
| right: axes.len(), | ||
| }); | ||
| } | ||
| debug_assert!(axes.iter().all(|i| i < &self.rank())); | ||
| self.dims = axes.iter().map(|&i| self.dims[i]).collect(); | ||
| Ok(self) | ||
| } | ||
| /// Repeated the specified `dim` a number of `times`. | ||
| pub fn repeat(mut self, dim: usize, times: usize) -> Result<Shape, ShapeError> { | ||
| if dim >= self.rank() { | ||
| return Err(ShapeError::OutOfBounds { | ||
| dim, | ||
| rank: self.rank(), | ||
| }); | ||
| } | ||
| self.dims[dim] *= times; | ||
| Ok(self) | ||
| } | ||
| /// Returns a new shape where the specified `dim` is reduced to size 1. | ||
| pub fn reduce(mut self, dim: usize) -> Result<Shape, ShapeError> { | ||
| if dim >= self.rank() { | ||
| return Err(ShapeError::OutOfBounds { | ||
| dim, | ||
| rank: self.rank(), | ||
| }); | ||
| } | ||
| self.dims[dim] = 1; | ||
| Ok(self) | ||
| } | ||
| /// Concatenates all shapes into a new one along the given dimension. | ||
| pub fn cat<'a, I>(shapes: I, dim: usize) -> Result<Self, ShapeError> | ||
| where | ||
| I: IntoIterator<Item = &'a Shape>, | ||
| { | ||
| let mut iter = shapes.into_iter(); | ||
| let first = iter.next().ok_or(ShapeError::Empty)?; | ||
| if dim >= first.rank() { | ||
| return Err(ShapeError::OutOfBounds { | ||
| dim, | ||
| rank: first.rank(), | ||
| }); | ||
| } | ||
| let mut shape = first.clone(); | ||
| for s in iter { | ||
| if s.rank() != shape.rank() { | ||
| return Err(ShapeError::RankMismatch { | ||
| left: shape.rank(), | ||
| right: s.rank(), | ||
| }); | ||
| } | ||
| if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] { | ||
| return Err(ShapeError::IncompatibleShapes { | ||
| left: shape.clone(), | ||
| right: s.clone(), | ||
| }); | ||
| } | ||
| shape[dim] += s[dim]; | ||
| } | ||
| Ok(shape) | ||
| } | ||
| /// Compute the output shape from the given slices. | ||
| pub fn slice(mut self, slices: &[Slice]) -> Result<Self, ShapeError> { | ||
| if slices.len() > self.rank() { | ||
| return Err(ShapeError::RankMismatch { | ||
| left: self.rank(), | ||
| right: slices.len(), | ||
| }); | ||
| } | ||
| slices | ||
| .iter() | ||
| .zip(self.iter_mut()) | ||
| .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size)); | ||
| Ok(self) | ||
| } | ||
| /// Compute the output shape for binary operations with broadcasting support. | ||
| /// | ||
| /// - Shapes must be of the same rank (missing dimensions are not handled automatically). | ||
| /// - Two dimensions are compatible if they are equal, or one of them is 1. | ||
| /// | ||
| /// For example, a shape `[1, 1, 2, 4]` can be broadcast into `[7, 6, 2, 4]` | ||
| /// because its axes are either equal or 1. On the other hand, a shape `[2, 2]` | ||
| /// can *not* be broadcast into `[2, 4]`. | ||
| pub fn broadcast(&self, other: &Self) -> Result<Self, ShapeError> { | ||
| Self::broadcast_many([self, other]) | ||
| } | ||
| /// Compute the broadcasted output shape across multiple input shapes. | ||
| /// | ||
| /// See also [broadcast](Self::broadcast). | ||
| pub fn broadcast_many<'a, I>(shapes: I) -> Result<Self, ShapeError> | ||
| where | ||
| I: IntoIterator<Item = &'a Shape>, | ||
| { | ||
| let mut iter = shapes.into_iter(); | ||
| let mut broadcasted = iter.next().ok_or(ShapeError::Empty)?.clone(); | ||
| let rank = broadcasted.rank(); | ||
| for shape in iter { | ||
| if shape.rank() != rank { | ||
| return Err(ShapeError::RankMismatch { | ||
| left: rank, | ||
| right: shape.rank(), | ||
| }); | ||
| } | ||
| for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() { | ||
| match (*d_lhs, d_rhs) { | ||
| (a, b) if a == b => {} // same | ||
| (1, b) => *d_lhs = b, // broadcast to rhs | ||
| (_a, 1) => {} // keep existing dimension | ||
| _ => { | ||
| return Err(ShapeError::IncompatibleDims { | ||
| left: *d_lhs, | ||
| right: d_rhs, | ||
| dim, | ||
| }); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| Ok(broadcasted) | ||
| } | ||
| /// Expand this shape to match the target shape, following broadcasting rules. | ||
| pub fn expand(&self, target: Shape) -> Result<Shape, ShapeError> { | ||
| let target_rank = target.rank(); | ||
| if self.rank() > target_rank { | ||
| return Err(ShapeError::RankMismatch { | ||
| left: self.rank(), | ||
| right: target_rank, | ||
| }); | ||
| } | ||
| for (i, (dim_target, dim_self)) in target.iter().rev().zip(self.iter().rev()).enumerate() { | ||
| if dim_self != dim_target && *dim_self != 1 { | ||
| return Err(ShapeError::IncompatibleDims { | ||
| left: *dim_self, | ||
| right: *dim_target, | ||
| dim: target_rank - i - 1, | ||
| }); | ||
| } | ||
| } | ||
| Ok(target) | ||
| } | ||
| /// Reshape this shape to the target shape. | ||
| pub fn reshape(&self, target: Shape) -> Result<Shape, ShapeError> { | ||
| if self.num_elements() != target.num_elements() { | ||
| return Err(ShapeError::IncompatibleShapes { | ||
| left: self.clone(), | ||
| right: target, | ||
| }); | ||
| } | ||
| Ok(target) | ||
| } | ||
| } | ||
| /// Compute the output shape for matrix multiplication with broadcasting support. | ||
| /// | ||
| /// The last two dimensions are treated as matrices, while preceding dimensions | ||
| /// follow broadcast semantics similar to elementwise operations. | ||
| pub fn calculate_matmul_output(lhs: &Shape, rhs: &Shape) -> Result<Shape, ShapeError> { | ||
| let rank = lhs.rank(); | ||
| if rank != rhs.rank() { | ||
| return Err(ShapeError::RankMismatch { | ||
| left: rank, | ||
| right: rhs.rank(), | ||
| }); | ||
| } | ||
| if lhs[rank - 1] != rhs[rank - 2] { | ||
| return Err(ShapeError::IncompatibleShapes { | ||
| left: lhs.clone(), | ||
| right: rhs.clone(), | ||
| }); | ||
| } | ||
| let mut shape = if rank > 2 { | ||
| // Broadcast leading dims | ||
| Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))? | ||
| } else { | ||
| Shape::new([]) | ||
| }; | ||
| shape.extend([lhs[rank - 2], rhs[rank - 1]]); | ||
| Ok(shape) | ||
| } | ||
// Consuming iteration yields each dimension size in order.
impl IntoIterator for Shape {
    type Item = usize;
    type IntoIter = alloc::vec::IntoIter<Self::Item>;
    fn into_iter(self) -> Self::IntoIter {
        self.dims.into_iter()
    }
}
// Forward indexing to the underlying dims, so `shape[i]` and `shape[a..b]`
// both work (panics on out-of-bounds, like slice indexing).
impl<Idx> Index<Idx> for Shape
where
    Idx: SliceIndex<[usize]>,
{
    type Output = Idx::Output;
    fn index(&self, index: Idx) -> &Self::Output {
        &self.dims[index]
    }
}
// Mutable counterpart of the `Index` impl above.
impl<Idx> IndexMut<Idx> for Shape
where
    Idx: SliceIndex<[usize]>,
{
    fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
        &mut self.dims[index]
    }
}
// Allow `&shape` to behave like a slice `&[usize]` directly
impl Deref for Shape {
    type Target = [usize];
    fn deref(&self) -> &Self::Target {
        &self.dims
    }
}
// Allow `&shape` to behave like a mut slice `&mut [usize]` directly
impl DerefMut for Shape {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.dims
    }
}
// Conversion sugar
impl<const D: usize> From<[usize; D]> for Shape {
    fn from(dims: [usize; D]) -> Self {
        Shape::new(dims)
    }
}
impl<const D: usize> From<[i64; D]> for Shape {
    fn from(dims: [i64; D]) -> Self {
        Shape {
            // NOTE(review): `as usize` wraps negative values to huge sizes
            // instead of failing — callers are presumably expected to pass
            // non-negative dims; confirm before relying on this for user input.
            dims: dims.into_iter().map(|d| d as usize).collect(),
        }
    }
}
impl<const D: usize> From<[i32; D]> for Shape {
    fn from(dims: [i32; D]) -> Self {
        Shape {
            // NOTE(review): same caveat as the i64 conversion — negative
            // values wrap rather than error.
            dims: dims.into_iter().map(|d| d as usize).collect(),
        }
    }
}
// Copies the borrowed dims into an owned `Vec`.
impl From<&[usize]> for Shape {
    fn from(dims: &[usize]) -> Self {
        Shape { dims: dims.into() }
    }
}
impl From<Vec<i64>> for Shape {
    fn from(shape: Vec<i64>) -> Self {
        Self {
            // NOTE(review): `as usize` wraps negative values — see the
            // `[i64; D]` conversion above.
            dims: shape.into_iter().map(|d| d as usize).collect(),
        }
    }
}
impl From<Vec<u64>> for Shape {
    fn from(shape: Vec<u64>) -> Self {
        Self {
            // NOTE(review): truncates on 32-bit targets where usize < u64.
            dims: shape.into_iter().map(|d| d as usize).collect(),
        }
    }
}
// Takes ownership of the vector without copying.
impl From<Vec<usize>> for Shape {
    fn from(shape: Vec<usize>) -> Self {
        Self { dims: shape }
    }
}
// Clones the borrowed vector (convenience alongside the `&[usize]` impl).
impl From<&Vec<usize>> for Shape {
    fn from(shape: &Vec<usize>) -> Self {
        Self {
            dims: shape.clone(),
        }
    }
}
// Inverse conversion: recover the owned dims vector.
impl From<Shape> for Vec<usize> {
    fn from(shape: Shape) -> Self {
        shape.dims
    }
}
| #[cfg(test)] | ||
| #[allow(clippy::identity_op, reason = "useful for clarity")] | ||
| mod tests { | ||
| use super::*; | ||
| use crate::s; | ||
| use alloc::vec; | ||
    // Basic accessors: rank/num_dims agree and count the dimensions.
    #[test]
    fn num_dims_and_rank() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(4, shape.num_dims());
        assert_eq!(4, shape.rank());
    }
    // num_elements is the product of all dimension sizes.
    #[test]
    fn num_elements() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(120, shape.num_elements());
    }
    // Consuming iteration yields the dimension sizes themselves.
    #[test]
    fn test_shape_into_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.into_iter().sum::<usize>(), 14);
    }
    // Each dimension becomes the full covering range 0..dim.
    #[test]
    fn test_into_ranges() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
    }
    #[test]
    fn test_to_vec() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]);
    }
    // Slice specifications are clamped to the shape's bounds; negative
    // indices wrap from the end; missing dims default to full ranges.
    #[allow(clippy::single_range_in_vec_init)]
    #[test]
    fn test_into_slices() {
        let slices = Shape::new([3]).into_slices(1..4);
        assert_eq!(slices[0].to_range(3), 1..3);
        let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
        assert_eq!(slices[0].to_range(3), 1..3);
        assert_eq!(slices[1].to_range(4), 0..2);
        let slices = Shape::new([3]).into_slices(..-2);
        assert_eq!(slices[0].to_range(3), 0..1);
        let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 1..2);
        let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 2..3);
    }
    // Shape supports both scalar and range indexing via `Index`.
    #[test]
    fn test_shape_index() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape[0], 2);
        assert_eq!(shape[1], 3);
        assert_eq!(shape[2], 4);
        assert_eq!(shape[3], 5);
        // Works with ranges
        assert_eq!(shape[1..3], *&[3, 4]);
        assert_eq!(shape[1..=2], *&[3, 4]);
        assert_eq!(shape[..], *&[2, 3, 4, 5]);
    }
    // Slice methods (first/last/is_empty) come via Deref to [usize].
    #[test]
    fn test_shape_slice_methods() {
        let shape = Shape::new([2, 3, 4, 5]);
        let dim = shape.first();
        assert_eq!(dim, Some(&2));
        let dim = shape.last();
        assert_eq!(dim, Some(&5));
        assert!(!shape.is_empty());
        let shape = Shape::new([]);
        assert!(shape.is_empty());
    }
    #[test]
    fn test_shape_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        for (d, sd) in dims.iter().zip(shape.iter()) {
            assert_eq!(d, sd);
        }
    }
    // iter_mut allows in-place modification of each dimension.
    #[test]
    fn test_shape_iter_mut() {
        let mut shape = Shape::new([2, 3, 4, 5]);
        for d in shape.iter_mut() {
            *d += 1;
        }
        assert_eq!(&shape.dims, &[3, 4, 5, 6]);
    }
    // as_slice and deref coercion expose the same underlying slice.
    #[test]
    fn test_shape_as_slice() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.as_slice(), dims.as_slice());
        // Deref coercion
        let shape_slice: &[usize] = &shape;
        assert_eq!(shape_slice, *&[2, 3, 4, 5]);
    }
    // as_mut_slice and mutable range indexing both mutate the dims in place.
    #[test]
    fn test_shape_as_mut_slice() {
        let mut dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);
        let shape_mut = shape.as_mut_slice();
        assert_eq!(shape_mut, dims.as_mut_slice());
        shape_mut[1] = 6;
        assert_eq!(shape_mut, &[2, 6, 4, 5]);
        let mut shape = Shape::new(dims);
        let shape = &mut shape[..];
        shape[1] = 6;
        assert_eq!(shape, shape_mut)
    }
    // Flatten collapses to a single dimension, preserving the element count.
    #[test]
    fn test_shape_flatten() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.num_elements(), 120);
        let shape = shape.flatten();
        assert_eq!(shape.num_elements(), 120);
        assert_eq!(&shape.dims, &[120]);
    }
    // Row-major raveling: last dimension has stride 1.
    #[test]
    fn test_ravel() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.ravel_index(&[0, 0, 0, 0]), 0);
        assert_eq!(
            shape.ravel_index(&[1, 2, 3, 4]),
            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
        );
    }
#[test]
fn test_shape_insert_remove_push() {
    let dims = [2, 3, 4, 5];
    let mut shape = Shape::new(dims);
    // Inserting a dim and removing it again round-trips the shape.
    let size = 6;
    shape.insert(1, size);
    assert_eq!(shape, Shape::new([2, 6, 3, 4, 5]));
    assert_eq!(shape.remove(1), size);
    assert_eq!(shape, Shape::new(dims));
    // `push` appends a trailing dim.
    shape.push(6);
    assert_eq!(shape, Shape::new([2, 3, 4, 5, 6]));
}
#[test]
fn test_shape_swap_permute() {
    let dims = [2, 3, 4, 5];
    // Swapping two axes and then permuting them back restores the original.
    let swapped = Shape::new(dims).swap(1, 2).unwrap();
    assert_eq!(&swapped.dims, &[2, 4, 3, 5]);
    let restored = swapped.permute(&[0, 2, 1, 3]).unwrap();
    assert_eq!(restored, Shape::new(dims));
}
#[test]
#[should_panic]
fn test_shape_swap_out_of_bounds() {
    // Axis 4 does not exist on a rank-4 shape; unwrapping the error panics.
    Shape::new([2, 3, 4, 5]).swap(0, 4).unwrap();
}
#[test]
#[should_panic]
fn test_shape_permute_incomplete() {
    // A permutation must cover every axis; a rank-3 permutation of a
    // rank-4 shape is rejected, so unwrapping panics.
    Shape::new([2, 3, 4, 5]).permute(&[0, 2, 1]).unwrap();
}
#[test]
fn test_shape_repeat() {
    // Repeating dim 2 three times multiplies its size: 4 * 3 = 12.
    let shape = Shape::new([2, 3, 4, 5]);
    assert_eq!(shape.repeat(2, 3).unwrap(), Shape::new([2, 3, 12, 5]));
}
#[test]
fn test_shape_repeat_invalid() {
    // Dim 5 is out of bounds for a rank-4 shape.
    let result = Shape::new([2, 3, 4, 5]).repeat(5, 3);
    assert_eq!(result, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
}
#[test]
fn test_shape_reduce() {
    // Reducing dim 2 collapses it to size 1.
    let shape = Shape::new([2, 3, 4, 5]);
    assert_eq!(shape.reduce(2).unwrap(), Shape::new([2, 3, 1, 5]));
}
#[test]
fn test_shape_reduce_invalid() {
    // Dim 5 is out of bounds for a rank-4 shape.
    let result = Shape::new([2, 3, 4, 5]).reduce(5);
    assert_eq!(result, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
}
#[test]
fn test_shape_broadcast_binary() {
    // Size-1 dims broadcast against the other operand's sizes.
    let lhs = Shape::new([1, 1, 2, 4]);
    let rhs = Shape::new([7, 6, 2, 1]);
    assert_eq!(lhs.broadcast(&rhs).unwrap(), Shape::new([7, 6, 2, 4]));
}
#[test]
fn test_shape_broadcast_rank_mismatch() {
    // Broadcasting requires equal ranks.
    let lhs = Shape::new([1, 2, 4]);
    let rhs = Shape::new([7, 6, 2, 4]);
    assert_eq!(
        lhs.broadcast(&rhs),
        Err(ShapeError::RankMismatch { left: 3, right: 4 })
    );
}
#[test]
fn test_shape_broadcast_incompatible_dims() {
    // Dim 1 has sizes 2 vs 6: neither is 1, so they cannot broadcast.
    let lhs = Shape::new([1, 2, 2, 4]);
    let rhs = Shape::new([7, 6, 2, 1]);
    let expected = Err(ShapeError::IncompatibleDims {
        left: 2,
        right: 6,
        dim: 1,
    });
    assert_eq!(lhs.broadcast(&rhs), expected);
}
#[test]
fn test_shape_broadcast_many() {
    // N-ary broadcast combines size-1 dims across all operands.
    let a = Shape::new([1, 1, 2, 4]);
    let b = Shape::new([7, 1, 2, 1]);
    let c = Shape::new([7, 6, 1, 1]);
    let out = Shape::broadcast_many([&a, &b, &c]).unwrap();
    assert_eq!(out, Shape::new([7, 6, 2, 4]));
}
#[test]
fn test_shape_broadcast_many_rank_mismatch() {
    let a = Shape::new([1, 1, 2, 4]);
    let b = Shape::new([7, 1, 2, 1]);
    let c = Shape::new([1, 6, 1]); // rank 3: mismatch against rank 4
    assert_eq!(
        Shape::broadcast_many([&a, &b, &c]),
        Err(ShapeError::RankMismatch { left: 4, right: 3 })
    );
}
#[test]
fn test_shape_broadcast_many_incompatible_dims() {
    let a = Shape::new([1, 1, 2, 4]);
    let b = Shape::new([7, 1, 2, 1]);
    let c = Shape::new([4, 6, 1, 1]); // dim 0: 7 vs 4 cannot broadcast
    let expected = Err(ShapeError::IncompatibleDims {
        left: 7,
        right: 4,
        dim: 0,
    });
    assert_eq!(Shape::broadcast_many([&a, &b, &c]), expected);
}
#[test]
fn test_shape_broadcast_many_empty() {
    // Broadcasting zero shapes is an error, not an identity.
    assert_eq!(Shape::broadcast_many(&[]), Err(ShapeError::Empty));
}
#[test]
fn test_shape_matmul_2d() {
    // (2x4) @ (4x2) -> (2x2)
    let lhs = Shape::new([2, 4]);
    let rhs = Shape::new([4, 2]);
    assert_eq!(
        calculate_matmul_output(&lhs, &rhs).unwrap(),
        Shape::new([2, 2])
    );
}
#[test]
fn test_shape_matmul_4d_broadcasted() {
    // Batch dims broadcast; the trailing two dims follow matmul rules.
    let lhs = Shape::new([1, 3, 2, 4]);
    let rhs = Shape::new([2, 1, 4, 2]);
    assert_eq!(
        calculate_matmul_output(&lhs, &rhs).unwrap(),
        Shape::new([2, 3, 2, 2])
    );
}
#[test]
fn test_shape_matmul_invalid_rank() {
    // Operand ranks must match.
    let lhs = Shape::new([3, 2, 4]);
    let rhs = Shape::new([2, 1, 4, 2]);
    assert_eq!(
        calculate_matmul_output(&lhs, &rhs),
        Err(ShapeError::RankMismatch { left: 3, right: 4 })
    );
}
#[test]
fn test_shape_matmul_invalid_shape() {
    // Inner dims (4 vs 3) do not line up; both shapes are reported back.
    let lhs = Shape::new([1, 3, 2, 4]);
    let rhs = Shape::new([2, 1, 3, 2]);
    let expected = Err(ShapeError::IncompatibleShapes {
        left: lhs.clone(),
        right: rhs.clone(),
    });
    assert_eq!(calculate_matmul_output(&lhs, &rhs), expected);
}
#[test]
fn test_shape_matmul_invalid_broadcast() {
    // Batch dim 1 has sizes 3 vs 2: cannot broadcast.
    let lhs = Shape::new([1, 3, 2, 4]);
    let rhs = Shape::new([2, 2, 4, 2]);
    let expected = Err(ShapeError::IncompatibleDims {
        left: 3,
        right: 2,
        dim: 1,
    });
    assert_eq!(calculate_matmul_output(&lhs, &rhs), expected);
}
#[test]
fn test_shape_cat() {
    // Concatenation along dim 0 sums the sizes on that axis.
    let out = Shape::cat(
        &[
            Shape::new([2, 3, 4, 5]),
            Shape::new([1, 3, 4, 5]),
            Shape::new([4, 3, 4, 5]),
        ],
        0,
    )
    .unwrap();
    assert_eq!(out, Shape::new([7, 3, 4, 5]));
    // Same along an interior axis.
    let out = Shape::cat(
        &[
            Shape::new([2, 3, 4, 5]),
            Shape::new([2, 3, 2, 5]),
            Shape::new([2, 3, 1, 5]),
        ],
        2,
    )
    .unwrap();
    assert_eq!(out, Shape::new([2, 3, 7, 5]));
}
#[test]
fn test_shape_cat_empty() {
    // Concatenating zero shapes is an error.
    assert_eq!(Shape::cat(&[], 0), Err(ShapeError::Empty));
}
#[test]
fn test_shape_cat_dim_out_of_bounds() {
    let a = Shape::new([2, 3, 4, 5]);
    let b = Shape::new([2, 3, 4, 5]);
    assert_eq!(
        Shape::cat(&[a, b], 4),
        Err(ShapeError::OutOfBounds { dim: 4, rank: 4 })
    );
}
#[test]
fn test_shape_cat_rank_mismatch() {
    let a = Shape::new([2, 3, 4, 5]);
    let b = Shape::new([2, 3, 4, 5, 6]);
    assert_eq!(
        Shape::cat(&[a, b], 0),
        Err(ShapeError::RankMismatch { left: 4, right: 5 })
    );
}
#[test]
fn test_shape_cat_incompatible_shapes() {
    // All non-concat dims must match exactly; dim 0 differs here.
    let a = Shape::new([2, 3, 4, 5]);
    let b = Shape::new([1, 3, 4, 5]);
    let expected = Err(ShapeError::IncompatibleShapes {
        left: a.clone(),
        right: b.clone(),
    });
    assert_eq!(Shape::cat(&[a, b], 1), expected);
}
#[test]
fn test_shape_slice_output_shape_basic() {
    // Unit-step slices: the output size is simply end - start.
    let shape = Shape::new([10, 10, 10]);
    let slices = [
        Slice::new(0, Some(5), 1), // 5 elements
        Slice::new(2, Some(8), 1), // 6 elements
    ];
    assert_eq!(shape.slice(&slices).unwrap(), Shape::new([5, 6, 10]));
}
#[test]
fn test_shape_slice_output_shape_with_positive_steps() {
    // Steps > 1: the output size is ceil(range / step).
    let shape = Shape::new([20, 20, 20, 30]);
    let slices = [
        Slice::new(0, Some(10), 2), // [0,2,4,6,8] -> 5 elements
        Slice::new(1, Some(9), 3),  // [1,4,7] -> 3 elements
        Slice::new(0, Some(7), 4),  // [0,4] -> 2 elements
    ];
    assert_eq!(shape.slice(&slices).unwrap(), Shape::new([5, 3, 2, 30]));
}
#[test]
fn test_shape_slice_output_shape_with_negative_steps() {
    // Negative steps traverse backward but yield the same element count.
    let shape = Shape::new([20, 20, 20]);
    let slices = [
        Slice::new(0, Some(10), -1), // 10 elements traversed backward
        Slice::new(2, Some(8), -2),  // [7,5,3] -> 3 elements
    ];
    assert_eq!(shape.slice(&slices).unwrap(), Shape::new([10, 3, 20]));
}
#[test]
fn test_shape_slice_output_shape_mixed_steps() {
    // Mix of unit, negative, and larger positive steps.
    let shape = Shape::new([20, 20, 20]);
    let slices = [
        Slice::from_range_stepped(1..6, 1),   // 5 elements
        Slice::from_range_stepped(0..10, -3), // [9,6,3,0] -> 4 elements
        Slice::from_range_stepped(2..14, 4),  // [2,6,10] -> 3 elements
    ];
    assert_eq!(shape.slice(&slices).unwrap(), Shape::new([5, 4, 3]));
}
#[test]
fn test_shape_slice_output_shape_partial_dims() {
    // Fewer slices than dims: the trailing dims pass through unchanged.
    let shape = Shape::new([10, 20, 30, 40]);
    let slices = [
        Slice::from_range_stepped(2..7, 2), // [2,4,6] -> 3 elements
    ];
    assert_eq!(shape.slice(&slices).unwrap(), Shape::new([3, 20, 30, 40]));
}
#[test]
fn test_shape_slice_output_shape_edge_cases() {
    // Single element, step larger than the range, and an empty range.
    let shape = Shape::new([10, 20, 30]);
    let slices = [
        Slice::from_range_stepped(0..1, 1),    // single element
        Slice::from_range_stepped(0..10, 100), // step > range -> 1 element
        Slice::from_range_stepped(5..5, 1),    // empty range -> 0 elements
    ];
    assert_eq!(shape.slice(&slices).unwrap(), Shape::new([1, 1, 0]));
}
#[test]
fn test_shape_slice_output_shape_empty() {
    // No slices at all: the original shape comes back untouched.
    let shape = Shape::new([10, 20, 30]);
    let slices = [];
    assert_eq!(shape.slice(&slices).unwrap(), Shape::new([10, 20, 30]));
}
#[test]
fn test_shape_slice_output_shape_uneven_division() {
    // Range sizes that don't divide evenly by the step round up.
    let shape = Shape::new([20, 20, 20]);
    let slices = [
        Slice::from_range_stepped(0..7, 3),  // ceil(7/3) = 3: [0,3,6]
        Slice::from_range_stepped(0..11, 4), // ceil(11/4) = 3: [0,4,8]
        Slice::from_range_stepped(1..10, 5), // ceil(9/5) = 2: [1,6]
    ];
    assert_eq!(shape.slice(&slices).unwrap(), Shape::new([3, 3, 2]));
}
#[test]
fn test_shape_expand() {
    // Size-1 dims expand to the target sizes.
    let shape = Shape::new([1, 3, 1]);
    let target = Shape::new([2, 3, 4]);
    assert_eq!(shape.expand(target.clone()).unwrap(), target);
}
#[test]
fn test_shape_expand_higher_rank() {
    // Expanding may also prepend leading dims to a lower-rank source.
    let shape = Shape::new([1, 4]);
    let target = Shape::new([2, 3, 4]);
    assert_eq!(shape.expand(target.clone()).unwrap(), target);
}
#[test]
fn test_shape_expand_invalid_rank() {
    // The target rank may not be lower than the source rank.
    let shape = Shape::new([1, 3, 1]);
    let target = Shape::new([3, 4]);
    assert_eq!(
        shape.expand(target),
        Err(ShapeError::RankMismatch { left: 3, right: 2 })
    );
}
#[test]
fn test_shape_expand_incompatible_dims() {
    // Only size-1 dims may change; dim 2 is 2 vs 4.
    let shape = Shape::new([1, 3, 2]);
    let target = Shape::new([2, 3, 4]);
    let expected = Err(ShapeError::IncompatibleDims {
        left: 2,
        right: 4,
        dim: 2,
    });
    assert_eq!(shape.expand(target), expected);
}
#[test]
fn test_shape_reshape() {
    // Reshape succeeds when the element count is preserved (120 both ways).
    let shape = Shape::new([2, 3, 4, 5]);
    let target = Shape::new([1, 2, 12, 5]);
    assert_eq!(shape.reshape(target.clone()).unwrap(), target);
}
#[test]
fn test_shape_reshape_invalid() {
    // 120 elements cannot reshape into 240; both shapes are reported back.
    let shape = Shape::new([2, 3, 4, 5]);
    let target = Shape::new([2, 2, 12, 5]);
    let expected = Err(ShapeError::IncompatibleShapes {
        left: shape.clone(),
        right: target.clone(),
    });
    assert_eq!(shape.clone().reshape(target), expected);
}
| } |
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display