burn-tensor
Advanced tools
| use burn_backend::{Backend, Element, tensor::Device}; | ||
| use burn_std::DType; | ||
/// Options for tensor creation.
///
/// This struct allows specifying the `device` and/or data type (`dtype`) when creating a tensor.
///
/// When `dtype` is left as `None`, the backend chooses its default dtype for the
/// tensor kind at creation time (see [`TensorCreationOptions::new`]).
#[derive(Debug, Clone)]
pub struct TensorCreationOptions<B: Backend> {
    /// Device where the tensor will be created.
    pub device: Device<B>,
    /// Optional data type.
    /// If `None`, the dtype will be inferred on creation from the backend's default dtype for the tensor kind.
    pub dtype: Option<DType>,
}
| impl<B: Backend> Default for TensorCreationOptions<B> { | ||
| /// Returns new options with the backend's default device. | ||
| fn default() -> Self { | ||
| Self::new(Default::default()) | ||
| } | ||
| } | ||
| impl<B: Backend> TensorCreationOptions<B> { | ||
| /// Create new options with a specific device. | ||
| /// | ||
| /// Data type will be inferred on creation from the backend's default dtype for the tensor kind. | ||
| pub fn new(device: Device<B>) -> Self { | ||
| Self { | ||
| device, | ||
| dtype: None, | ||
| } | ||
| } | ||
| /// Set the tensor creation data type. | ||
| pub fn with_dtype(mut self, dtype: DType) -> Self { | ||
| self.dtype = Some(dtype); | ||
| self | ||
| } | ||
| /// Set the tensor creation device. | ||
| pub fn with_device(mut self, device: Device<B>) -> Self { | ||
| self.device = device; | ||
| self | ||
| } | ||
| /// Create options with backend's default device and float dtype. | ||
| pub fn float() -> Self { | ||
| Self::default().with_dtype(<B::FloatElem as Element>::dtype()) | ||
| } | ||
| /// Create options with backend's default device and int dtype. | ||
| pub fn int() -> Self { | ||
| Self::default().with_dtype(<B::IntElem as Element>::dtype()) | ||
| } | ||
| /// Create options with backend's default device and bool dtype. | ||
| pub fn bool() -> Self { | ||
| Self::default().with_dtype(<B::BoolElem as Element>::dtype()) | ||
| } | ||
| /// Returns the tensor data type, or a provided default if not set. | ||
| /// | ||
| /// This is useful for cases where [`TensorCreationOptions`] may not have an explicit `dtype`. | ||
| pub fn dtype_or(&self, dtype: DType) -> DType { | ||
| self.dtype.unwrap_or(dtype) | ||
| } | ||
| } | ||
| impl<B: Backend> From<&Device<B>> for TensorCreationOptions<B> { | ||
| /// Convenience conversion from a reference to a device. | ||
| /// | ||
| /// Example: | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::TensorCreationOptions; | ||
| /// | ||
| /// fn example<B: Backend>(device: B::Device) { | ||
| /// let options: TensorCreationOptions<B> = (&device).into(); | ||
| /// } | ||
| /// ``` | ||
| fn from(device: &Device<B>) -> Self { | ||
| TensorCreationOptions::new(device.clone()) | ||
| } | ||
| } | ||
| impl<B: Backend> From<(&Device<B>, DType)> for TensorCreationOptions<B> { | ||
| /// Convenience conversion for a specified `(&device, dtype)` tuple. | ||
| fn from(args: (&Device<B>, DType)) -> Self { | ||
| TensorCreationOptions::new(args.0.clone()).with_dtype(args.1) | ||
| } | ||
| } |
| { | ||
| "git": { | ||
| "sha1": "91dd62c92a4469b392e4d85ca6a11c969604d0d2" | ||
| "sha1": "67e7a966afc2d0f6e78b47c7814201884d277e89" | ||
| }, | ||
| "path_in_vcs": "crates/burn-tensor" | ||
| } |
+7
-3
@@ -15,3 +15,3 @@ # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO | ||
| name = "burn-tensor" | ||
| version = "0.20.0-pre.6" | ||
| version = "0.20.0" | ||
| authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"] | ||
@@ -71,2 +71,6 @@ build = false | ||
| ] | ||
| tracing = [ | ||
| "burn-std/tracing", | ||
| "burn-backend/tracing", | ||
| ] | ||
@@ -78,7 +82,7 @@ [lib] | ||
| [dependencies.burn-backend] | ||
| version = "=0.20.0-pre.6" | ||
| version = "=0.20.0" | ||
| default-features = false | ||
| [dependencies.burn-std] | ||
| version = "=0.20.0-pre.6" | ||
| version = "=0.20.0" | ||
| default-features = false | ||
@@ -85,0 +89,0 @@ |
@@ -335,2 +335,20 @@ use crate::backend::Backend; | ||
/// Applies the hard swish function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{hard\_swish}\(x\) = x \cdot \text{hard\_sigmoid}(x) = x \cdot \max(0, \min(1, \frac{x}{6} + 0.5))
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`hard_swish(x) = x * hard_sigmoid(x) = x * max(0, min(1, x/6 + 0.5))`"
)]
pub fn hard_swish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    // hard_sigmoid is parameterized as max(0, min(1, alpha * x + beta));
    // alpha = 1/6 and beta = 0.5 yield the standard hard-swish gate.
    tensor.clone().mul(hard_sigmoid(tensor, 1.0 / 6.0, 0.5))
}
| /// Applies the Mish function as described in the paper in | ||
@@ -337,0 +355,0 @@ /// [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681). |
+180
-3
@@ -23,2 +23,24 @@ use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend}; | ||
| /// Create a boolean tensor from data on the given device. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `data` - The tensor data. | ||
| /// * `device` - The device on which the tensor will be allocated. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor. | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::{Tensor, Bool}; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device); | ||
| /// println!("{tensor}"); | ||
| /// } | ||
| /// ``` | ||
| pub fn from_bool(data: TensorData, device: &B::Device) -> Self { | ||
@@ -29,2 +51,20 @@ Self::new(B::bool_from_data(data.convert::<B::BoolElem>(), device)) | ||
| /// Convert the bool tensor into an int tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// An integer tensor where `true` is converted to `1` and `false` to `0`. | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::{Tensor, Bool}; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device); | ||
| /// let int_tensor = bool_tensor.int(); | ||
| /// println!("{int_tensor}"); // [1, 0, 1] | ||
| /// } | ||
| /// ``` | ||
| pub fn int(self) -> Tensor<B, D, Int> { | ||
@@ -35,2 +75,20 @@ Tensor::new(B::bool_into_int(self.primitive)) | ||
| /// Convert the bool tensor into a float tensor. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A float tensor where `true` is converted to `1.0` and `false` to `0.0`. | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::{Tensor, Bool}; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device); | ||
| /// let float_tensor = bool_tensor.float(); | ||
| /// println!("{float_tensor}"); // [1.0, 0.0, 1.0] | ||
| /// } | ||
| /// ``` | ||
| pub fn float(self) -> Tensor<B, D> { | ||
@@ -41,2 +99,16 @@ Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive))) | ||
| /// Inverses boolean values. | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::{Tensor, Bool}; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device); | ||
| /// let inverted = tensor.bool_not(); | ||
| /// println!("{inverted}"); // [[false, true], [true, false]] | ||
| /// } | ||
| /// ``` | ||
| pub fn bool_not(self) -> Self { | ||
@@ -46,3 +118,26 @@ Tensor::new(B::bool_not(self.primitive)) | ||
| /// Performs logical and (`&&`) on two boolean tensors | ||
| /// Performs logical and (`&&`) on two boolean tensors. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `rhs` - The right-hand side tensor for the AND operation. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor where each element is the result of `self[i] && rhs[i]`. | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::{Tensor, Bool}; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device); | ||
| /// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device); | ||
| /// let result = a.bool_and(b); | ||
| /// println!("{result}"); // [[true, false], [false, false]] | ||
| /// } | ||
| /// ``` | ||
| pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> { | ||
@@ -52,3 +147,26 @@ Tensor::new(B::bool_and(self.primitive, rhs.primitive)) | ||
| /// Performs logical or (`||`) on two boolean tensors | ||
| /// Performs logical or (`||`) on two boolean tensors. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `rhs` - The right-hand side tensor for the OR operation. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor where each element is the result of `self[i] || rhs[i]`. | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::{Tensor, Bool}; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device); | ||
| /// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device); | ||
| /// let result = a.bool_or(b); | ||
| /// println!("{result}"); // [[true, true], [true, false]] | ||
| /// } | ||
| /// ``` | ||
| pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> { | ||
@@ -58,3 +176,27 @@ Tensor::new(B::bool_or(self.primitive, rhs.primitive)) | ||
| /// Performs logical xor (`^`) on two boolean tensors | ||
| /// Performs logical xor (`^`) on two boolean tensors. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `rhs` - The right-hand side tensor for the XOR operation. | ||
| /// | ||
| /// # Returns | ||
| /// | ||
| /// A boolean tensor where each element is the result of `self[i] ^ rhs[i]`. | ||
| /// Returns `true` when exactly one of the operands is `true`. | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::{Tensor, Bool}; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device); | ||
| /// let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device); | ||
| /// let result = a.bool_xor(b); | ||
| /// println!("{result}"); // [[false, true], [true, false]] | ||
| /// } | ||
| /// ``` | ||
| pub fn bool_xor(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> { | ||
@@ -70,2 +212,20 @@ Tensor::new(B::bool_xor(self.primitive, rhs.primitive)) | ||
| /// the non-zero elements in that dimension. | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::{Tensor, Bool}; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// let tensor = Tensor::<B, 2, Bool>::from_bool( | ||
| /// [[true, false, true], [false, true, false], [false, true, false]].into(), | ||
| /// &device, | ||
| /// ); | ||
| /// let indices = tensor.nonzero(); | ||
| /// println!("{}", indices[0]); // [0, 0, 1, 2] | ||
| /// println!("{}", indices[1]); // [0, 2, 1, 1] | ||
| /// } | ||
| /// ``` | ||
| pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> { | ||
@@ -104,2 +264,19 @@ try_read_sync(self.nonzero_async()) | ||
| /// result contains the indices of a non-zero element. | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::{Tensor, Bool}; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// let tensor = Tensor::<B, 2, Bool>::from_bool( | ||
| /// [[true, false, true], [false, true, false], [false, true, false]].into(), | ||
| /// &device, | ||
| /// ); | ||
| /// let indices = tensor.argwhere(); | ||
| /// println!("{indices}"); // [[0, 0], [0, 2], [1, 1], [2, 1]] | ||
| /// } | ||
| /// ``` | ||
| pub fn argwhere(self) -> Tensor<B, 2, Int> { | ||
@@ -106,0 +283,0 @@ try_read_sync(self.argwhere_async()) |
+14
-64
@@ -147,49 +147,2 @@ use crate::ops::FloatElem; | ||
| pub(crate) fn reshape_args_usize<const D1: usize, const D2: usize>( | ||
| original: &Shape, | ||
| target: &Shape, | ||
| ) -> Self { | ||
| let mut check = Self::Ok; | ||
| if original.num_elements() != target.num_elements() { | ||
| check = check.register( | ||
| "Reshape", | ||
| TensorError::new( | ||
| "The given shape doesn't have the same number of elements as the current \ | ||
| tensor.", | ||
| ) | ||
| .details(format!( | ||
| "Current shape: {:?}, target shape: {:?}.", | ||
| original.dims, target.dims | ||
| )), | ||
| ); | ||
| } | ||
| check | ||
| } | ||
| pub(crate) fn reshape_args_i64<const D: usize>(target: &[i64; D]) -> Self { | ||
| let mut check = Self::Ok; | ||
| if target.iter().any(|&dim| dim < -1) { | ||
| check = check.register( | ||
| "Reshape", | ||
| TensorError::new( | ||
| "The given shape cannot contain negative dimensions (other than -1).", | ||
| ) | ||
| .details(format!("Target shape: {target:?}.")), | ||
| ); | ||
| } | ||
| if target.iter().filter(|&x| x == &-1).count() > 1 { | ||
| check = check.register( | ||
| "Reshape", | ||
| TensorError::new("The given shape cannot contain more than one -1.") | ||
| .details(format!("Target shape: {target:?}.")), | ||
| ); | ||
| } | ||
| check | ||
| } | ||
| pub(crate) fn movedim_args_usize<const D: usize>(dim: usize) -> Self { | ||
@@ -1490,2 +1443,16 @@ let mut check = Self::Ok; | ||
/// Unwraps the result of a shape reshape, routing any shape error through the
/// standard tensor-check reporting so the panic message matches other checks.
pub(crate) fn unwrap_shape_reshape(result: Result<Shape, burn_std::ShapeError>) -> Shape {
    match result {
        Ok(shape) => shape,
        // `shape.reshape(new_shape)` should only return `ShapeError::Invalid`.
        Err(burn_std::ShapeError::Invalid { reason }) => {
            macros::check!({
                TensorCheck::Ok.register("Reshape", crate::check::TensorError::new(reason))
            });
            // `check!` panics on a registered error, so this point is never reached.
            unreachable!()
        }
        // Any other error variant would indicate a bug in `reshape` itself; fail loudly.
        Err(e) => panic!("{e:?}"),
    }
}
| #[cfg(test)] | ||
@@ -1498,19 +1465,2 @@ mod tests { | ||
| #[should_panic] | ||
| fn reshape_invalid_shape() { | ||
| check!(TensorCheck::reshape_args_usize::<2, 2>( | ||
| &Shape::new([2, 2]), | ||
| &Shape::new([1, 3]) | ||
| )); | ||
| } | ||
| #[test] | ||
| fn reshape_valid_shape() { | ||
| check!(TensorCheck::reshape_args_usize::<2, 2>( | ||
| &Shape::new([2, 2]), | ||
| &Shape::new([1, 4]) | ||
| )); | ||
| } | ||
| #[test] | ||
| #[should_panic] | ||
| fn index_range_exceed_dimension() { | ||
@@ -1517,0 +1467,0 @@ let slices = vec![Slice::from(0..2), Slice::from(0..4), Slice::from(1..8)]; |
+176
-2
@@ -205,2 +205,172 @@ use crate::AsIndex; | ||
/// Applies element wise inverse sine operation.
///
#[cfg_attr(doc, doc = r#"$y_i = \asin\(x_i\)$"#)]
#[cfg_attr(not(doc), doc = "`y_i = asin(x_i)`")]
///
/// NOTE(review): asin is mathematically defined for inputs in `[-1, 1]`;
/// behavior for inputs outside that range is backend-defined (likely NaN) — confirm.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///
///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
///     println!("{}", tensor.asin()); // [ 0.0000, -1.5708, 1.5708]
/// }
/// ```
pub fn asin(self) -> Self {
    // Delegate to the backend's element-wise asin kernel on the float primitive.
    Self::new(TensorPrimitive::Float(B::float_asin(
        self.primitive.tensor(),
    )))
}
/// Applies element wise inverse hyperbolic sine operation.
///
#[cfg_attr(doc, doc = r#"$y_i = \asinh\(x_i\)$"#)]
#[cfg_attr(not(doc), doc = "`y_i = asinh(x_i)`")]
///
/// asinh is defined for all real inputs.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///
///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
///     println!("{}", tensor.asinh()); // [ 0.0000, -0.8814, 0.8814]
/// }
/// ```
pub fn asinh(self) -> Self {
    // Delegate to the backend's element-wise asinh kernel on the float primitive.
    Self::new(TensorPrimitive::Float(B::float_asinh(
        self.primitive.tensor(),
    )))
}
/// Applies element wise inverse cosine operation.
///
#[cfg_attr(doc, doc = r#"$y_i = \acos\(x_i\)$"#)]
#[cfg_attr(not(doc), doc = "`y_i = acos(x_i)`")]
///
/// NOTE(review): acos is mathematically defined for inputs in `[-1, 1]`;
/// behavior for inputs outside that range is backend-defined (likely NaN) — confirm.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///
///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);
///     println!("{}", tensor.acos()); // [1.5708, 3.1416, 0.0]
/// }
/// ```
pub fn acos(self) -> Self {
    // Delegate to the backend's element-wise acos kernel on the float primitive.
    Self::new(TensorPrimitive::Float(B::float_acos(
        self.primitive.tensor(),
    )))
}
| /// Applies element wise inverse hyperbolic cosine operation. | ||
| /// | ||
| #[cfg_attr(doc, doc = r#"$y_i = \acosh\(x_i\)$"#)] | ||
| #[cfg_attr(not(doc), doc = "`y_i = acosh(x_i)`")] | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::Tensor; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// | ||
| /// let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device); | ||
| /// println!("{}", tensor.sinh()); // [0.0000, 1.3170, 1.7627] | ||
| /// } | ||
| /// ``` | ||
| pub fn acosh(self) -> Self { | ||
| Self::new(TensorPrimitive::Float(B::float_acosh( | ||
| self.primitive.tensor(), | ||
| ))) | ||
| } | ||
| /// Applies element wise inverse tangent operation. | ||
| /// | ||
| #[cfg_attr(doc, doc = r#"$y_i = \atan\(x_i\)$"#)] | ||
| #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")] | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::Tensor; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// | ||
| /// let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device); | ||
| /// println!("{}", tensor.sinh()); // [ 0.0, -0.7854, 1.1071] | ||
| /// } | ||
| /// ``` | ||
| pub fn atan(self) -> Self { | ||
| Self::new(TensorPrimitive::Float(B::float_atan( | ||
| self.primitive.tensor(), | ||
| ))) | ||
| } | ||
| /// Applies element wise inverse hyperbolic tangent operation. | ||
| /// | ||
| #[cfg_attr(doc, doc = r#"$y_i = \atan\(x_i\)$"#)] | ||
| #[cfg_attr(not(doc), doc = "`y_i = atan(x_i)`")] | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust | ||
| /// use burn_tensor::backend::Backend; | ||
| /// use burn_tensor::Tensor; | ||
| /// | ||
| /// fn example<B: Backend>() { | ||
| /// let device = Default::default(); | ||
| /// | ||
| /// let tensor = Tensor::<B, 1>::from_data([0.0, -0.5, 0.5], &device); | ||
| /// println!("{}", tensor.sinh()); // [ 0.0, -0.5493, 0.5493] | ||
| /// } | ||
| /// ``` | ||
| pub fn atanh(self) -> Self { | ||
| Self::new(TensorPrimitive::Float(B::float_atanh( | ||
| self.primitive.tensor(), | ||
| ))) | ||
| } | ||
/// Applies element wise inverse tangent operation using the signs of arguments to determine the correct quadrant.
///
#[cfg_attr(doc, doc = r#"$z_i = \atan2\(y_i, x_i\)$"#)]
#[cfg_attr(not(doc), doc = "`z_i = atan2(y_i, x_i)`")]
///
/// `self` supplies the `y` coordinates and `other` the `x` coordinates.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///
///     let lhs = Tensor::<B, 1>::from_data([-2.0, 2.0, -2.0], &device);
///     let rhs = Tensor::<B, 1>::from_data([1.0, -1.0, -1.0], &device);
///     println!("{}", lhs.atan2(rhs)); // [-1.1071, 2.0344, -2.0344]
/// }
/// ```
pub fn atan2(self, other: Self) -> Self {
    // Delegate to the backend's element-wise atan2 kernel on both float primitives.
    Self::new(TensorPrimitive::Float(B::float_atan2(
        self.primitive.tensor(),
        other.primitive.tensor(),
    )))
}
| /// Applies element wise round operation. | ||
@@ -683,7 +853,11 @@ /// | ||
| /// ``` | ||
| pub fn grid_sample_2d(self, grid: Tensor<B, D>, options: GridSampleOptions) -> Tensor<B, D> { | ||
| pub fn grid_sample_2d( | ||
| self, | ||
| grid: Tensor<B, D>, | ||
| options: impl Into<GridSampleOptions>, | ||
| ) -> Tensor<B, D> { | ||
| Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d( | ||
| self.primitive.tensor(), | ||
| grid.primitive.tensor(), | ||
| options, | ||
| options.into(), | ||
| ))) | ||
@@ -690,0 +864,0 @@ } |
@@ -11,2 +11,3 @@ pub(crate) mod check; | ||
| mod numeric; | ||
| mod options; | ||
| mod pad; | ||
@@ -22,4 +23,5 @@ mod take; | ||
| pub use numeric::*; | ||
| pub use options::*; | ||
| pub use transaction::*; | ||
| pub use burn_backend::tensor::IndexingUpdateOp; |
@@ -83,4 +83,4 @@ use alloc::vec::Vec; | ||
| /// ``` | ||
| pub fn pad(self, padding: (usize, usize, usize, usize), mode: PadMode) -> Self { | ||
| match mode { | ||
| pub fn pad(self, padding: (usize, usize, usize, usize), mode: impl Into<PadMode>) -> Self { | ||
| match mode.into() { | ||
| PadMode::Constant(value) => pad_constant(self, padding, value), | ||
@@ -87,0 +87,0 @@ PadMode::Reflect => pad_reflect(self, padding), |
@@ -72,3 +72,3 @@ use crate::{AsIndex, BasicOps, Int, Tensor, backend::Backend, check, check::TensorCheck}; | ||
| // Add all indices dimensions | ||
| for idx_dim in indices_dims { | ||
| for &idx_dim in indices_dims.iter() { | ||
| new_shape.push(idx_dim); | ||
@@ -75,0 +75,0 @@ } |
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display