Skip to content

Commit

Permalink
Adds static tape functions to control tape activation
Browse files Browse the repository at this point in the history
  • Loading branch information
auto-differentiation-dev committed Mar 28, 2024
1 parent 41af595 commit 5f6a608
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 7 deletions.
13 changes: 13 additions & 0 deletions docs/ref/tape.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ or `!c++ nullptr` if no active tape has been set.
Note that this is a thread-local pointer - calling this function in different
threads gives different results.

#### `setActive`

`#!c++ static void setActive(Tape* t)` static function that sets the given tape as the
globally active one. This is equivalent to `t.activate()`.

It may throw [`TapeAlreadyActive`](exceptions.md) if another tape is
already active for the current thread.

#### `deactivateAll`

`#!c++ static void deactivateAll()` deactivates any currently active tapes.
Equivalent to `auto t = Tape::getActive(); if (t) t->deactivate();`.

#### `registerInput`

`#!c++ void registerInput(active_type& inp)` registers the given variable with the tape and start recording dependents of it. A call to this function or its overloads is required in order to calculate adjoints.
Expand Down
19 changes: 12 additions & 7 deletions src/XAD/Tape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,8 @@ class Tape
Tape& operator=(const Tape&) = delete;

// recording control
XAD_INLINE void activate()
{
if (active_tape_ != nullptr)
throw TapeAlreadyActive();
else
active_tape_ = this;
}
XAD_INLINE void activate() { setActive(this); }

XAD_INLINE void deactivate()
{
if (active_tape_ == this)
Expand All @@ -106,6 +101,16 @@ class Tape
XAD_INLINE bool isActive() const { return active_tape_ == this; }
XAD_INLINE static Tape* getActive() { return active_tape_; }

XAD_INLINE static void setActive(Tape* t)
{
if (active_tape_ != nullptr)
throw TapeAlreadyActive();
else
active_tape_ = t;
}

XAD_INLINE static void deactivateAll() { active_tape_ = nullptr; }

XAD_INLINE void registerInput(active_type& inp)
{
if (!inp.shouldRecord()) // already registered
Expand Down
27 changes: 27 additions & 0 deletions test/Tape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,33 @@ TEST(Tape, canInitializeDeactivated)
EXPECT_NE(nullptr, Tape<float>::getActive());
}

TEST(Tape, canActivateStatically)
{
using xad::Tape;
Tape<float> s(false);

EXPECT_FALSE(s.isActive());
EXPECT_EQ(nullptr, Tape<float>::getActive());

xad::Tape<float>::setActive(&s);

EXPECT_TRUE(s.isActive());
EXPECT_NE(nullptr, Tape<float>::getActive());
}

TEST(Tape, canDeactivateGlobally)
{
using xad::Tape;

EXPECT_EQ(nullptr, Tape<double>::getActive());

Tape<double> s;

EXPECT_TRUE(s.isActive());
Tape<double>::deactivateAll();
EXPECT_FALSE(s.isActive());
}

TEST(Tape, isMovable)
{
xad::Tape<double> s(false);
Expand Down

0 comments on commit 5f6a608

Please sign in to comment.