Skip to content

Commit

Permalink
burn-autodiff: support no-std
Browse files Browse the repository at this point in the history
  • Loading branch information
ivila committed Feb 24, 2025
1 parent a1e7912 commit c9e5898
Show file tree
Hide file tree
Showing 34 changed files with 267 additions and 178 deletions.
190 changes: 95 additions & 95 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ tch = "0.15.0"

ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
portable-atomic = { version = "1.10.0" }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b41cbd82d53f091e76f56cad58c277fe2481c48e" }
Expand Down
3 changes: 3 additions & 0 deletions crates/burn-autodiff/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", opt
derive-new = { workspace = true }
spin = { workspace = true }
log = { workspace = true }
hashbrown = { workspace = true }
num-traits = { workspace = true }
portable-atomic = { workspace = true }

[dev-dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [
Expand Down
1 change: 1 addition & 0 deletions crates/burn-autodiff/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
runtime::AutodiffClient,
tensor::AutodiffTensor,
};
use alloc::{format, string::String};
use burn_tensor::{
backend::{AutodiffBackend, Backend},
ops::{BoolTensor, IntTensor, QuantizedTensor},
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-autodiff/src/checkpoint/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ use super::{
retro_forward::RetroForwards,
state::{BackwardStates, State},
};
use crate::collections::HashMap;
use crate::graph::NodeID;
use std::collections::HashMap;

use alloc::{vec, vec::Vec};

#[derive(new, Debug)]
/// Links a [NodeID] to its autodiff graph [NodeRef]
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-autodiff/src/checkpoint/builder.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::{
collections::HashMap,
graph::{ComputingProperty, NodeID, NodeSteps},
tensor::AutodiffTensor,
};
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use burn_tensor::backend::Backend;
use std::{any::Any, collections::HashMap, sync::Arc};
use core::any::Any;

use super::{
base::{Checkpointer, NodeTree},
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-autodiff/src/checkpoint/retro_forward.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::collections::HashMap;
use crate::graph::NodeID;

use std::{collections::HashMap, fmt::Debug, sync::Arc};
use alloc::sync::Arc;
use core::fmt::Debug;

use super::state::{BackwardStates, State};

Expand Down
4 changes: 3 additions & 1 deletion crates/burn-autodiff/src/checkpoint/state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::{any::Any, collections::HashMap};
use core::any::Any;

use crate::collections::HashMap;
use crate::graph::NodeID;
use alloc::boxed::Box;

/// In order to accept arbitrary node output in the same hashmap, we need to upcast them to any.
pub(crate) type StateContent = Box<dyn Any + Send>;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/checkpoint/strategy.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use core::fmt::Debug;
use std::sync::Arc;

use burn_tensor::backend::Backend;

use crate::{graph::ComputingProperty, tensor::AutodiffTensor};
use alloc::sync::Arc;

use super::{
builder::{ActionType, CheckpointerBuilder},
Expand Down
6 changes: 3 additions & 3 deletions crates/burn-autodiff/src/graph/base.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use super::NodeID;
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
use std::collections::HashMap;
use crate::{checkpoint::base::Checkpointer, collections::HashMap, grads::Gradients};
use alloc::{boxed::Box, vec::Vec};

/// Backward step for reverse mode autodiff.
pub trait Step: Send + std::fmt::Debug {
pub trait Step: Send + core::fmt::Debug {
/// Executes the step and consumes it.
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer);
/// Depth of the operation relative to the first node added to a graph.
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/graph/node.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use alloc::{sync::Arc, vec::Vec};
use portable_atomic::{AtomicU64, Ordering};

use crate::checkpoint::retro_forward::RetroForward;
use crate::runtime::AutodiffClientImpl;
Expand Down
7 changes: 5 additions & 2 deletions crates/burn-autodiff/src/graph/traversal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use super::{Step, StepBoxed};
use crate::NodeID;
use std::collections::{HashMap, HashSet};
use crate::{
collections::{HashMap, HashSet},
NodeID,
};
use alloc::vec::Vec;

/// Breadth-first search algorithm.
pub struct BreadthFirstSearch;
Expand Down
10 changes: 10 additions & 0 deletions crates/burn-autodiff/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]

Expand Down Expand Up @@ -34,3 +35,12 @@ pub use backend::*;

#[cfg(feature = "export_tests")]
mod tests;

/// A facade around HashMap and HashSet.
/// This avoids elaborate import wrangling having to happen in every module.
mod collections {
#[cfg(not(feature = "std"))]
pub use hashbrown::{HashMap, HashSet};
#[cfg(feature = "std")]
pub use std::collections::{HashMap, HashSet};
}
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/ops/activation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::marker::PhantomData;
use core::marker::PhantomData;

use crate::{
checkpoint::{
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/ops/backward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ use burn_tensor::backend::Backend;
/// Concrete types implementing this trait should not have any state.
/// If a state is necessary during the backward pass,
/// they should be declared with the associated type 'State'.
pub trait Backward<B, const N: usize>: Send + std::fmt::Debug
pub trait Backward<B, const N: usize>: Send + core::fmt::Debug
where
Self: Sized + 'static,
B: Backend,
{
/// Associated type to compute the backward pass.
type State: Clone + Send + std::fmt::Debug + 'static;
type State: Clone + Send + core::fmt::Debug + 'static;

/// The backward pass.
fn backward(
Expand Down
13 changes: 7 additions & 6 deletions crates/burn-autodiff/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ use crate::{
graph::{ComputingProperty, NodeID, NodeRef, Requirement, Step},
tensor::AutodiffTensor,
};
use alloc::{boxed::Box, vec::Vec};
use burn_tensor::{backend::Backend, ops::FloatTensor, Shape, TensorMetadata};
use std::marker::PhantomData;
use core::marker::PhantomData;

/// Operation in preparation.
///
Expand Down Expand Up @@ -134,7 +135,7 @@ where
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, ComputePropertyDone>
where
B: Backend,
S: Clone + Send + std::fmt::Debug + 'static,
S: Clone + Send + core::fmt::Debug + 'static,
BO: Backward<B, N, State = S>,
{
/// Prepare an operation that requires a state during the backward pass.
Expand All @@ -161,7 +162,7 @@ where
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, UnTracked>
where
B: Backend,
S: Clone + Send + std::fmt::Debug + 'static,
S: Clone + Send + core::fmt::Debug + 'static,
BO: Backward<B, N, State = S>,
{
/// Finish the preparation of an untracked operation and returns the output tensor.
Expand All @@ -184,7 +185,7 @@ where
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Tracked>
where
B: Backend,
S: Clone + Send + std::fmt::Debug + 'static,
S: Clone + Send + core::fmt::Debug + 'static,
BO: Backward<B, N, State = S>,
{
/// Finish the preparation of a tracked operation and returns the output tensor.
Expand Down Expand Up @@ -235,7 +236,7 @@ struct OpsStep<B, T, SB, const N: usize>
where
B: Backend,
T: Backward<B, N, State = SB>,
SB: Clone + Send + std::fmt::Debug + 'static,
SB: Clone + Send + core::fmt::Debug + 'static,
{
ops: Ops<SB, N>,
backward: T,
Expand All @@ -246,7 +247,7 @@ impl<B, T, SB, const N: usize> Step for OpsStep<B, T, SB, N>
where
B: Backend,
T: Backward<B, N, State = SB>,
SB: Clone + Send + std::fmt::Debug + 'static,
SB: Clone + Send + core::fmt::Debug + 'static,
{
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer) {
self.backward.backward(self.ops, grads, checkpointer);
Expand Down
5 changes: 3 additions & 2 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Autodiff};
use alloc::vec::Vec;

use burn_tensor::{
backend::Backend,
Expand Down Expand Up @@ -31,7 +32,7 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_reshape(tensor, shape)
}

fn bool_slice(tensor: BoolTensor<B>, ranges: &[std::ops::Range<usize>]) -> BoolTensor<B> {
fn bool_slice(tensor: BoolTensor<B>, ranges: &[core::ops::Range<usize>]) -> BoolTensor<B> {
B::bool_slice(tensor, ranges)
}

Expand All @@ -41,7 +42,7 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {

fn bool_slice_assign(
tensor: BoolTensor<Self>,
ranges: &[std::ops::Range<usize>],
ranges: &[core::ops::Range<usize>],
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
B::bool_slice_assign(tensor, ranges, value)
Expand Down
7 changes: 4 additions & 3 deletions crates/burn-autodiff/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Autodiff};
use alloc::vec::Vec;

use burn_tensor::{
backend::Backend,
Expand Down Expand Up @@ -27,7 +28,7 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_reshape(tensor, shape)
}

fn int_slice(tensor: IntTensor<B>, ranges: &[std::ops::Range<usize>]) -> IntTensor<B> {
fn int_slice(tensor: IntTensor<B>, ranges: &[core::ops::Range<usize>]) -> IntTensor<B> {
B::int_slice(tensor, ranges)
}

Expand All @@ -37,7 +38,7 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {

fn int_slice_assign(
tensor: IntTensor<B>,
ranges: &[std::ops::Range<usize>],
ranges: &[core::ops::Range<usize>],
value: IntTensor<B>,
) -> IntTensor<B> {
B::int_slice_assign(tensor, ranges, value)
Expand Down Expand Up @@ -305,7 +306,7 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_random(shape, distribution, device)
}

fn int_arange(range: std::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self> {
fn int_arange(range: core::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self> {
B::int_arange(range, device)
}

Expand Down
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::Range;
use core::ops::Range;

use burn_tensor::{
backend::Backend,
Expand Down
22 changes: 13 additions & 9 deletions crates/burn-autodiff/src/ops/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use alloc::{vec, vec::Vec};
use std::marker::PhantomData;
use alloc::{boxed::Box, vec, vec::Vec};
use core::marker::PhantomData;

#[cfg(not(feature = "std"))]
#[allow(unused_imports, reason = "required on aarch64, unused on x86_64")]
use num_traits::float::Float;

use crate::{
checkpoint::{
Expand Down Expand Up @@ -1127,15 +1131,15 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>

fn float_slice(
tensor: FloatTensor<Self>,
ranges: &[std::ops::Range<usize>],
ranges: &[core::ops::Range<usize>],
) -> FloatTensor<Self> {
#[derive(Debug)]
struct Index;

#[derive(new, Debug)]
struct RetroSlice<B: Backend> {
tensor_id: NodeID,
ranges: Vec<std::ops::Range<usize>>,
ranges: Vec<core::ops::Range<usize>>,
_backend: PhantomData<B>,
}

Expand All @@ -1148,7 +1152,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}

impl<B: Backend> Backward<B, 1> for Index {
type State = (Vec<std::ops::Range<usize>>, Shape, B::Device);
type State = (Vec<core::ops::Range<usize>>, Shape, B::Device);

fn backward(
self,
Expand Down Expand Up @@ -1186,7 +1190,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>

fn float_slice_assign(
tensor: FloatTensor<Self>,
ranges: &[std::ops::Range<usize>],
ranges: &[core::ops::Range<usize>],
value: FloatTensor<Self>,
) -> FloatTensor<Self> {
#[derive(Debug)]
Expand All @@ -1195,7 +1199,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
#[derive(new, Debug)]
struct RetroSliceAssign<B: Backend> {
tensor_id: NodeID,
ranges: Vec<std::ops::Range<usize>>,
ranges: Vec<core::ops::Range<usize>>,
value_id: NodeID,
_backend: PhantomData<B>,
}
Expand All @@ -1210,7 +1214,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}

impl<B: Backend> Backward<B, 2> for SliceAssign {
type State = (Vec<std::ops::Range<usize>>, Shape, B::Device);
type State = (Vec<core::ops::Range<usize>>, Shape, B::Device);

fn backward(
self,
Expand Down Expand Up @@ -2066,7 +2070,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
let ops = checkpointer.retrieve_node_output(ops.state);
let exponent = B::float_neg(B::float_powf_scalar(ops, 2.0));
let numerator = B::float_mul_scalar(B::float_exp(exponent), 2.0.elem());
let denominator = std::f64::consts::PI.sqrt().elem();
let denominator = core::f64::consts::PI.sqrt().elem();
let value = B::float_div_scalar(numerator, denominator);

B::float_mul(grad, value)
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/ops/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{checkpoint::strategy::CheckpointStrategy, Autodiff};
impl<B: Backend, C: CheckpointStrategy> TransactionOps<Self> for Autodiff<B, C> {
fn tr_execute(
transaction: TransactionPrimitive<Self>,
) -> impl std::future::Future<Output = burn_tensor::ops::TransactionPrimitiveResult> + 'static + Send
) -> impl core::future::Future<Output = burn_tensor::ops::TransactionPrimitiveResult> + 'static + Send
{
B::tr_execute(TransactionPrimitive {
read_floats: transaction
Expand Down
14 changes: 9 additions & 5 deletions crates/burn-autodiff/src/runtime/memory_management.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::{tensor::NodeRefCount, NodeID};
use std::{
use crate::{
collections::{HashMap, HashSet},
mem,
sync::Arc,
tensor::NodeRefCount,
NodeID,
};
use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec};
use core::mem;

#[derive(Default, Debug)]
pub struct GraphMemoryManagement {
Expand Down Expand Up @@ -82,7 +83,10 @@ impl GraphMemoryManagement {

fn clear_unused_roots(&mut self, to_delete: &mut Vec<NodeID>) {
for (id, parents) in self.nodes.iter() {
let is_useful = matches!(self.statuses.get(id), Some(NodeMemoryStatus::Useful));
let is_useful = matches!(
self.statuses.get(id.as_ref()),
Some(NodeMemoryStatus::Useful)
);

// Check if parents are either empty or absent from self.nodes
let parents_absent = parents.iter().all(|p| !self.nodes.contains_key(p));
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/runtime/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use burn_tensor::backend::Backend;
pub struct MutexClient;

impl core::fmt::Debug for MutexClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str("MutexClient")
}
}
Expand Down
Loading

0 comments on commit c9e5898

Please sign in to comment.