burn-autodiff: support no_std
ivila committed Feb 20, 2025
1 parent b89d674 commit d81c851
Showing 35 changed files with 116 additions and 59 deletions.
15 changes: 12 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
@@ -167,5 +167,7 @@ cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features
### For xtask crate ###
tracel-xtask = { version = "=1.1.8" }

+cfg_block = "0.2.0"
+
[profile.dev]
debug = 0 # Speed up compilation time and not necessary.
3 changes: 3 additions & 0 deletions crates/burn-autodiff/Cargo.toml
@@ -25,6 +25,9 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", opt
derive-new = { workspace = true }
spin = { workspace = true }
log = { workspace = true }
+cfg_block = { workspace = true }
+hashbrown = { workspace = true }
+num-traits = { workspace = true }

[dev-dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [
1 change: 1 addition & 0 deletions crates/burn-autodiff/src/backend.rs
@@ -4,6 +4,7 @@ use crate::{
runtime::AutodiffClient,
tensor::AutodiffTensor,
};
+use alloc::{format, string::String};
use burn_tensor::{
backend::{AutodiffBackend, Backend},
ops::{BoolTensor, IntTensor, QuantizedTensor},
4 changes: 3 additions & 1 deletion crates/burn-autodiff/src/checkpoint/base.rs
@@ -2,8 +2,10 @@ use super::{
retro_forward::RetroForwards,
state::{BackwardStates, State},
};
+use crate::collections::HashMap;
use crate::graph::NodeID;
-use std::collections::HashMap;
+
+use alloc::{vec, vec::Vec};

#[derive(new, Debug)]
/// Links a [NodeID] to its autodiff graph [NodeRef]
4 changes: 3 additions & 1 deletion crates/burn-autodiff/src/checkpoint/builder.rs
@@ -1,9 +1,11 @@
use crate::{
+collections::HashMap,
graph::{ComputingProperty, NodeID, NodeSteps},
tensor::AutodiffTensor,
};
+use alloc::{boxed::Box, sync::Arc, vec::Vec};
use burn_tensor::backend::Backend;
-use std::{any::Any, collections::HashMap, sync::Arc};
+use core::any::Any;

use super::{
base::{Checkpointer, NodeTree},
4 changes: 3 additions & 1 deletion crates/burn-autodiff/src/checkpoint/retro_forward.rs
@@ -1,6 +1,8 @@
+use crate::collections::HashMap;
use crate::graph::NodeID;

-use std::{collections::HashMap, fmt::Debug, sync::Arc};
+use alloc::sync::Arc;
+use core::fmt::Debug;

use super::state::{BackwardStates, State};

4 changes: 3 additions & 1 deletion crates/burn-autodiff/src/checkpoint/state.rs
@@ -1,6 +1,8 @@
-use std::{any::Any, collections::HashMap};
+use core::any::Any;

+use crate::collections::HashMap;
use crate::graph::NodeID;
+use alloc::boxed::Box;

/// In order to accept arbitrary node output in the same hashmap, we need to upcast them to any.
pub(crate) type StateContent = Box<dyn Any + Send>;
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/checkpoint/strategy.rs
@@ -1,9 +1,9 @@
use core::fmt::Debug;
-use std::sync::Arc;

use burn_tensor::backend::Backend;

use crate::{graph::ComputingProperty, tensor::AutodiffTensor};
+use alloc::sync::Arc;

use super::{
builder::{ActionType, CheckpointerBuilder},
6 changes: 3 additions & 3 deletions crates/burn-autodiff/src/graph/base.rs
@@ -1,9 +1,9 @@
use super::NodeID;
-use crate::{checkpoint::base::Checkpointer, grads::Gradients};
-use std::collections::HashMap;
+use crate::{checkpoint::base::Checkpointer, collections::HashMap, grads::Gradients};
+use alloc::{boxed::Box, vec::Vec};

/// Backward step for reverse mode autodiff.
-pub trait Step: Send + std::fmt::Debug {
+pub trait Step: Send + core::fmt::Debug {
/// Executes the step and consumes it.
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer);
/// Depth of the operation relative to the first node added to a graph.
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/graph/node.rs
@@ -1,5 +1,5 @@
-use std::sync::atomic::{AtomicU64, Ordering};
-use std::sync::Arc;
+use alloc::{sync::Arc, vec::Vec};
+use core::sync::atomic::{AtomicU64, Ordering};

use crate::checkpoint::retro_forward::RetroForward;
use crate::runtime::AutodiffClientImpl;
7 changes: 5 additions & 2 deletions crates/burn-autodiff/src/graph/traversal.rs
@@ -1,6 +1,9 @@
use super::{Step, StepBoxed};
-use crate::NodeID;
-use std::collections::{HashMap, HashSet};
+use crate::{
+    collections::{HashMap, HashSet},
+    NodeID,
+};
+use alloc::vec::Vec;

/// Breadth for search algorithm.
pub struct BreadthFirstSearch;
13 changes: 13 additions & 0 deletions crates/burn-autodiff/src/lib.rs
@@ -1,3 +1,4 @@
+#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]

@@ -34,3 +35,15 @@ pub use backend::*;

#[cfg(feature = "export_tests")]
mod tests;

+/// A facade around for HashMap and HashSet.
+/// This avoids elaborate import wrangling having to happen in every module.
+mod collections {
+    cfg_block::cfg_block! {
+        if #[cfg(feature = "std")] {
+            pub use std::collections::{HashMap, HashSet};
+        } else {
+            pub use hashbrown::{HashMap, HashSet};
+        }
+    }
+}
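
The new collections module above is the heart of the no_std support: the modules touched in this diff now import HashMap/HashSet from crate::collections, and cfg_block! selects std::collections when the "std" feature is enabled, or the no_std-compatible hashbrown crate otherwise. Below is a minimal, self-contained sketch of that facade pattern using plain #[cfg] attributes instead of cfg_block!; the graph module, NodeId, and Gradients names are illustrative placeholders, not code from this commit.

// lib.rs-style sketch of the facade pattern in a hypothetical crate, assuming
// a Cargo.toml with a default "std" feature and a hashbrown dependency.
#![cfg_attr(not(feature = "std"), no_std)]

extern crate alloc;

mod collections {
    // With the "std" feature, re-export the standard-library maps;
    // without it, fall back to the no_std-compatible hashbrown crate.
    #[cfg(feature = "std")]
    pub use std::collections::{HashMap, HashSet};
    #[cfg(not(feature = "std"))]
    pub use hashbrown::{HashMap, HashSet};
}

mod graph {
    // Downstream modules only ever name crate::collections::*,
    // so none of them needs a cfg of its own.
    use crate::collections::HashMap;
    use alloc::vec::Vec;

    // Placeholder node identifier and gradient store for the sketch.
    pub type NodeId = u64;

    #[derive(Default)]
    pub struct Gradients {
        grads: HashMap<NodeId, Vec<f32>>,
    }

    impl Gradients {
        pub fn register(&mut self, node: NodeId, grad: Vec<f32>) {
            self.grads.insert(node, grad);
        }
    }
}

The cfg_block! macro used in the commit expands to an equivalent pair of feature-gated re-exports; keeping that choice in one module is why files such as checkpoint/base.rs and graph/traversal.rs can simply write use crate::collections::HashMap; without ever mentioning std or hashbrown directly.
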
2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/ops/activation.rs
@@ -1,4 +1,4 @@
-use std::marker::PhantomData;
+use core::marker::PhantomData;

use crate::{
checkpoint::{
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/ops/backward.rs
@@ -14,13 +14,13 @@ use burn_tensor::backend::Backend;
/// Concrete types implementing this trait should not have any state.
/// If a state is necessary during the backward pass,
/// they should be declared with the associated type 'State'.
-pub trait Backward<B, const N: usize>: Send + std::fmt::Debug
+pub trait Backward<B, const N: usize>: Send + core::fmt::Debug
where
Self: Sized + 'static,
B: Backend,
{
/// Associated type to compute the backward pass.
-type State: Clone + Send + std::fmt::Debug + 'static;
+type State: Clone + Send + core::fmt::Debug + 'static;

/// The backward pass.
fn backward(
13 changes: 7 additions & 6 deletions crates/burn-autodiff/src/ops/base.rs
@@ -10,8 +10,9 @@ use crate::{
graph::{ComputingProperty, NodeID, NodeRef, Requirement, Step},
tensor::AutodiffTensor,
};
+use alloc::{boxed::Box, vec::Vec};
use burn_tensor::{backend::Backend, ops::FloatTensor, Shape, TensorMetadata};
-use std::marker::PhantomData;
+use core::marker::PhantomData;

/// Operation in preparation.
///
@@ -134,7 +135,7 @@ where
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, ComputePropertyDone>
where
B: Backend,
-S: Clone + Send + std::fmt::Debug + 'static,
+S: Clone + Send + core::fmt::Debug + 'static,
BO: Backward<B, N, State = S>,
{
/// Prepare an operation that requires a state during the backward pass.
@@ -161,7 +162,7 @@ where
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, UnTracked>
where
B: Backend,
-S: Clone + Send + std::fmt::Debug + 'static,
+S: Clone + Send + core::fmt::Debug + 'static,
BO: Backward<B, N, State = S>,
{
/// Finish the preparation of an untracked operation and returns the output tensor.
@@ -184,7 +185,7 @@ where
impl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Tracked>
where
B: Backend,
-S: Clone + Send + std::fmt::Debug + 'static,
+S: Clone + Send + core::fmt::Debug + 'static,
BO: Backward<B, N, State = S>,
{
/// Finish the preparation of a tracked operation and returns the output tensor.
@@ -235,7 +236,7 @@ struct OpsStep<B, T, SB, const N: usize>
where
B: Backend,
T: Backward<B, N, State = SB>,
-SB: Clone + Send + std::fmt::Debug + 'static,
+SB: Clone + Send + core::fmt::Debug + 'static,
{
ops: Ops<SB, N>,
backward: T,
@@ -246,7 +247,7 @@ impl<B, T, SB, const N: usize> Step for OpsStep<B, T, SB, N>
where
B: Backend,
T: Backward<B, N, State = SB>,
-SB: Clone + Send + std::fmt::Debug + 'static,
+SB: Clone + Send + core::fmt::Debug + 'static,
{
fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer) {
self.backward.backward(self.ops, grads, checkpointer);
5 changes: 3 additions & 2 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
@@ -1,4 +1,5 @@
use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Autodiff};
+use alloc::vec::Vec;

use burn_tensor::{
backend::Backend,
@@ -31,7 +32,7 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_reshape(tensor, shape)
}

-fn bool_slice(tensor: BoolTensor<B>, ranges: &[std::ops::Range<usize>]) -> BoolTensor<B> {
+fn bool_slice(tensor: BoolTensor<B>, ranges: &[core::ops::Range<usize>]) -> BoolTensor<B> {
B::bool_slice(tensor, ranges)
}

@@ -41,7 +42,7 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {

fn bool_slice_assign(
tensor: BoolTensor<Self>,
-ranges: &[std::ops::Range<usize>],
+ranges: &[core::ops::Range<usize>],
value: BoolTensor<Self>,
) -> BoolTensor<Self> {
B::bool_slice_assign(tensor, ranges, value)
7 changes: 4 additions & 3 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -1,4 +1,5 @@
use crate::{checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor, Autodiff};
+use alloc::vec::Vec;

use burn_tensor::{
backend::Backend,
@@ -27,7 +28,7 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_reshape(tensor, shape)
}

-fn int_slice(tensor: IntTensor<B>, ranges: &[std::ops::Range<usize>]) -> IntTensor<B> {
+fn int_slice(tensor: IntTensor<B>, ranges: &[core::ops::Range<usize>]) -> IntTensor<B> {
B::int_slice(tensor, ranges)
}

@@ -37,7 +38,7 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {

fn int_slice_assign(
tensor: IntTensor<B>,
-ranges: &[std::ops::Range<usize>],
+ranges: &[core::ops::Range<usize>],
value: IntTensor<B>,
) -> IntTensor<B> {
B::int_slice_assign(tensor, ranges, value)
@@ -305,7 +306,7 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_random(shape, distribution, device)
}

-fn int_arange(range: std::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self> {
+fn int_arange(range: core::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self> {
B::int_arange(range, device)
}

2 changes: 1 addition & 1 deletion crates/burn-autodiff/src/ops/qtensor.rs
@@ -1,4 +1,4 @@
-use std::ops::Range;
+use core::ops::Range;

use burn_tensor::{
backend::Backend,
Diffs for the remaining changed files are not shown here.
