Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement direct integer update #159

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions neet/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from abc import ABCMeta, abstractmethod
import networkx as nx
import six
from .python import long


@six.add_metaclass(ABCMeta)
Expand Down Expand Up @@ -65,11 +66,17 @@ def state_space(self):
def _unsafe_update(self, state, index, pin, values, *args, **kwargs):
pass

def update(self, state, index=None, pin=None, values=None, *args, **kwargs):
def _unsafe_update_int(self, state, index, pin, values, *args, **kwargs):
space = self.state_space()
decode = space.decode
encode = space._unsafe_encode

if state not in space:
raise ValueError("the provided state is not in the network's state space")
array_state = decode(state)
self._unsafe_update(array_state, index, pin, values, *args, **kwargs)
return encode(array_state)

def update(self, state, index=None, pin=None, values=None, *args, **kwargs):
space = self.state_space()

if index is not None:
if index < 0 or index >= self.size:
Expand All @@ -89,6 +96,13 @@ def update(self, state, index=None, pin=None, values=None, *args, **kwargs):
if val < 0 or val >= bases[key]:
raise ValueError("invalid state in values argument")

if isinstance(state, (int, long)):
if state < 0 or state >= space.volume:
raise ValueError("the provided state is not in the network's state space")
return self._unsafe_update_int(state, index, pin, values, *args, **kwargs)
else:
if state not in space:
raise ValueError("the provided state is not in the network's state space")
return self._unsafe_update(state, index, pin, values, *args, **kwargs)

@abstractmethod
Expand Down
99 changes: 99 additions & 0 deletions test/boolean/test_eca.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,102 @@ def test_to_networkx_metadata(self):

self.assertEqual(nx_net.graph['code'], 30)
self.assertEqual(nx_net.graph['boundary'], (1, 0))


class TestECAIntegerUpdate(unittest.TestCase):
def test_invalid_state(self):
eca = ECA(30, 3)

with self.assertRaises(ValueError):
eca.update(-1)

with self.assertRaises(ValueError):
eca.update(9)

def test_update_closed(self):
eca = ECA(30, 1)

lattice = 0

self.assertEqual(0, eca.update(lattice))
self.assertEqual(0, lattice)

eca.size = 2
lattice = 0

self.assertEqual(0, eca.update(lattice))
self.assertEqual(0, lattice)

eca.size = 5
lattice = 4

self.assertEqual(14, eca.update(lattice))
self.assertEqual(4, lattice)

lattice = 14
self.assertEqual(19, eca.update(lattice))
self.assertEqual(14, lattice)

def test_update_open(self):
eca = ECA(30, 1, (0, 1))

lattice = 0
self.assertEqual(1, eca.update(lattice))
self.assertEqual(0, lattice)

eca.size = 2
lattice = 0
self.assertEqual(2, eca.update(lattice))
self.assertEqual(0, lattice)

eca.size = 5
lattice = 4
self.assertEqual(30, eca.update(lattice))
self.assertEqual(4, lattice)

lattice = 14
self.assertEqual(3, eca.update(lattice))
self.assertEqual(14, lattice)

def test_update_long_time_closed(self):
eca = ECA(45, 14)
lattice = 10571
expected = 5462

for n in range(1000):
lattice = eca.update(lattice)
self.assertEqual(expected, lattice)

def test_update_index(self):
eca = ECA(30, 5, (1, 1))

self.assertEqual(1, eca.update(0, index=0))
self.assertEqual(0, eca.update(0, index=1))
self.assertEqual(6, eca.update(4, index=1))

def test_update_pin_none(self):
eca = ECA(30, 5)

self.assertEqual(14, eca.update(4, pin=None))
self.assertEqual(19, eca.update(14, pin=[]))

def test_update_pin(self):
eca = ECA(30, 5)

self.assertEqual(12, eca.update(4, pin=[1]))
self.assertEqual(20, eca.update(12, pin=[1]))
self.assertEqual(21, eca.update(20, pin=[1]))

eca.boundary = (1, 1)
self.assertEqual(1, eca.update(0, pin=[-1]))
self.assertEqual(3, eca.update(1, pin=[0, -1]))

def test_update_values(self):
eca = ECA(30, 5)

self.assertEqual(10, eca.update(4, values={2: 0}))
self.assertEqual(17, eca.update(10, values={1: 0, 3: 0}))
self.assertEqual(10, eca.update(17, values={-1: 0}))

eca.boundary = (1, 1)
self.assertEqual(21, eca.update(0, values={2: 1}))
35 changes: 32 additions & 3 deletions test/boolean/test_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,35 @@ def test_to_networkx_metadata(self):

self.assertEqual(nx_net.graph['name'], net.metadata['name'])

# def test_draw(self):
# net = bnet.LogicNetwork([((0,), {'0'})])
# draw(net,labels='indices')

class TestLogicNetworkIntegerUpdate(unittest.TestCase):
def test_update(self):
net = LogicNetwork([((0,), {'0'})])
self.assertEqual(net.update(0, 0), 1)
self.assertEqual(net.update(1, 0), 0)
self.assertEqual(net.update(0), 1)
self.assertEqual(net.update(1), 0)

net = LogicNetwork([((1,), {'0', '1'}), ((0,), {'1'})])
self.assertEqual(net.update(0, 0), 1)
self.assertEqual(net.update(0, 1), 0)
self.assertEqual(net.update(0), 1)
self.assertEqual(net.update(2, 0), 3)
self.assertEqual(net.update(2, 1), 0)
self.assertEqual(net.update(2), 1)
self.assertEqual(net.update(1, 0), 1)
self.assertEqual(net.update(1, 1), 3)
self.assertEqual(net.update(1), 3)
self.assertEqual(net.update(3), 3)

net = LogicNetwork([((1, 2), {'01', '10'}),
((0, 2), {(0, 1), '10', (1, 1)}),
((0, 1), {'11'})])
self.assertEqual(net.update(2), 1)
self.assertEqual(net.update(7, 1), 7)
self.assertEqual(net.update(4), 3)
self.assertEqual(net.update(4, pin=[1]), 1)
self.assertEqual(net.update(4, pin=[0, 1]), 0)

self.assertEqual(net.update(4, values={0: 0}), 2)
self.assertEqual(net.update(4, pin=[1], values={0: 0}), 0)
92 changes: 92 additions & 0 deletions test/boolean/test_reca.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,95 @@ def test_reca_values(self):
self.assertEqual([0, 0, 1, 0, 0], reca.update(xs, values={-2: 0}))
self.assertEqual([0, 1, 0, 1, 0], reca.update(
xs, values={2: 0, -5: 0}))


class TestRewiredECAIntegerUpdate(unittest.TestCase):
def test_reproduce_closed_ecas(self):
reca = RewiredECA(30, size=7)
eca = ECA(30, size=7)
state = 8
for _ in range(10):
expect = eca.update(state)
got = reca.update(state)
self.assertTrue(np.array_equal(expect, got))
state = expect

def test_reproduce_open_ecas(self):
reca = RewiredECA(30, boundary=(1, 0), size=7)
eca = ECA(30, size=7, boundary=(1, 0))
state = 8
for _ in range(10):
expect = eca.update(state)
got = reca.update(state)
self.assertTrue(np.array_equal(expect, got))
state = expect

def test_rewired_network(self):
reca = RewiredECA(30, wiring=[
[-1, 0, 1, 2, 3], [0, 1, 2, 3, 4], [1, 2, 3, 4, 5]
])
self.assertEqual(25, reca.update(16))

reca.wiring[:, :] = [
[0, 4, 1, 2, 3], [0, 1, 2, 3, 4], [0, 2, 3, 4, 5]
]
self.assertEqual(26, reca.update(16))
self.assertEqual(8, reca.update(26))
self.assertEqual(28, reca.update(8))
self.assertEqual(4, reca.update(28))
self.assertEqual(14, reca.update(4))
self.assertEqual(18, reca.update(14))
self.assertEqual(28, reca.update(18))

def test_reca_index(self):
reca = RewiredECA(30, wiring=[
[0, 4, 1, 2, 3], [0, 1, 2, 3, 4], [0, 2, 3, 4, 5]
])

self.assertEqual(24, reca.update(16, index=3))
self.assertEqual(28, reca.update(24, index=2))
self.assertEqual(14, reca.update(12, index=1))
self.assertEqual(10, reca.update(10, index=0))

def test_reca_pin_none(self):
reca = RewiredECA(30, size=5)

self.assertEqual(14, reca.update(4, pin=None))
self.assertEqual(19, reca.update(14, pin=[]))

def test_reca_pin(self):
reca = RewiredECA(30, wiring=[
[-1, 4, 1, 2, -1], [0, 1, 2, 3, 4], [0, 2, 3, 4, 5]
])

self.assertEqual(12, reca.update(4, pin=[1]))
self.assertEqual(14, reca.update(12, pin=[3]))
self.assertEqual(14, reca.update(14, pin=[3, 2]))
self.assertEqual(10, reca.update(14, pin=[-2]))

reca.boundary = (1, 1)
self.assertEqual(5, reca.update(4, pin=[1, 3]))
self.assertEqual(7, reca.update(5, pin=[-2, -5]))
self.assertEqual(15, reca.update(7, pin=[0, 2]))

def test_reca_values_none(self):
reca = RewiredECA(30, size=5)

self.assertEqual(14, reca.update(4, values=None))
self.assertEqual(19, reca.update(14, values={}))

def test_reca_values(self):
reca = RewiredECA(30, wiring=[
[-1, 4, 1, 2, -1], [0, 1, 2, 3, 4], [0, 2, 3, 4, 5]
])

self.assertEqual(15, reca.update(4, values={0: 1}))
self.assertEqual(3, reca.update(15, values={-1: 0}))
self.assertEqual(31, reca.update(3, values={-2: 1}))
self.assertEqual(5, reca.update(31, values={2: 1, -5: 1}))

reca.boundary = (1, 1)
self.assertEqual(14, reca.update(4, values={0: 0}))
self.assertEqual(19, reca.update(14, values={-1: 1}))
self.assertEqual(4, reca.update(19, values={-2: 0}))
self.assertEqual(10, reca.update(4, values={2: 0, -5: 0}))
Loading