AttributeError: 'numpy.ndarray' object has no attribute 'at' #17

Open
CLL112 opened this issue Feb 13, 2025 · 10 comments
CLL112 commented Feb 13, 2025

I installed Python 3.10 and followed the tutorial for installation. When running examples/func_fit/xor_hyperneat.py, I encountered the following error:

Traceback (most recent call last):
  File "/home/longchen/huawei/tensorneat/examples/func_fit/xor_hyperneat.py", line 41, in <module>
    state, best = pipeline.auto_run(state)
  File "/home/longchen/huawei/tensorneat/src/tensorneat/pipeline.py", line 101, in auto_run
    compiled_step = jax.jit(self.step).lower(state).compile()
  File "/home/longchen/huawei/tensorneat/src/tensorneat/pipeline.py", line 82, in step
    pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(
  File "/home/longchen/huawei/tensorneat/src/tensorneat/algorithm/hyperneat/hyperneat.py", line 79, in transform
    ), self.substrate.make_conns(query_res)
  File "/home/longchen/huawei/tensorneat/src/tensorneat/algorithm/hyperneat/substrate/default.py", line 25, in make_conns
    return self.conns.at[:, -1].set(query_res)
AttributeError: 'numpy.ndarray' object has no attribute 'at'

It seems like the issue is related to the version of NumPy, but I have tried both 1.26.2 and 2.0.1, and I still get the error. Below are the versions of some of the packages in my environment:

flax                     0.10.3
gym                      0.26.2
gym-notices              0.0.8
gymnasium                1.0.0
gymnax                   0.0.8
humanize                 4.11.0
jax                      0.5.0
jax-cuda12-pjrt          0.5.0
jax-cuda12-plugin        0.5.0
jaxlib                   0.5.0
jaxopt                   0.8.3
Jinja2                   3.1.5
kiwisolver               1.4.8
nest-asyncio             1.6.0
numpy                    2.0.1
pandas                   2.2.3
pillow                   11.1.0
Pygments                 2.19.1
PyOpenGL                 3.1.9
pyparsing                3.2.1
python-dateutil          2.9.0.post0
pytinyrenderer           0.0.14
tensorboardX             2.6.2.2
tensorneat               0.1.0      
tensorstore              0.1.71

CLL112 commented Feb 13, 2025

The .at attribute seems to be a feature of jax.numpy arrays, not of the regular numpy.ndarray. However, modifying the code as follows:

import jax.numpy as jnp

self.coors = jnp.array(coors)  # jnp instead of np
self.nodes = jnp.array(nodes)  # jnp instead of np
self.conns = jnp.array(conns)  # jnp instead of np

def make_conns(self, query_res):
    return self.conns.at[:, -1].set(query_res)

Still results in an error:

Traceback (most recent call last):
  File "/home/longchen/huawei/tensorneat/examples/func_fit/xor_hyperneat.py", line 41, in <module>
    state, best = pipeline.auto_run(state)
  File "/home/longchen/huawei/tensorneat/src/tensorneat/pipeline.py", line 101, in auto_run
    compiled_step = jax.jit(self.step).lower(state).compile()
  File "/home/longchen/huawei/tensorneat/src/tensorneat/pipeline.py", line 82, in step
    pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(
  File "/home/longchen/huawei/tensorneat/src/tensorneat/algorithm/hyperneat/hyperneat.py", line 79, in transform
    ), self.substrate.make_conns(query_res)
  File "/home/longchen/huawei/tensorneat/src/tensorneat/algorithm/hyperneat/substrate/default.py", line 30, in make_conns
    return self.conns.at[:, -1].set(query_res)
  File "/home/longchen/miniconda3/envs/neat/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 799, in set
    return scatter._scatter_update(self.array, self.index, values, lax.scatter,
  File "/home/longchen/miniconda3/envs/neat/lib/python3.10/site-packages/jax/_src/ops/scatter.py", line 76, in _scatter_update
    return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
  File "/home/longchen/miniconda3/envs/neat/lib/python3.10/site-packages/jax/_src/ops/scatter.py", line 111, in _scatter_impl
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
  File "/home/longchen/miniconda3/envs/neat/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3263, in broadcast_to
    return util._broadcast_to(array, shape)
  File "/home/longchen/miniconda3/envs/neat/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 238, in _broadcast_to
    raise ValueError(f"Cannot broadcast to shape with fewer dimensions: {arr_shape=} {shape=}")
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(24, 1) shape=(24,)

WLS2002 added a commit that referenced this issue Feb 13, 2025

WLS2002 commented Feb 13, 2025

This is quite strange. I suspect this issue was caused by a JAX update. I have now fixed the problem.

Thank you very much for using TensorNEAT and for reporting this issue! If you encounter any other problems, I’d be happy to help.
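For reference, the failing call was self.conns.at[:, -1].set(query_res), where query_res arrived with shape (24, 1) while the sliced column expects shape (24,). Below is a minimal sketch of the kind of change that resolves both errors above (assuming query_res comes out of the CPPN query with a trailing singleton dimension; this is an illustration, not necessarily the exact content of the commit):

import jax.numpy as jnp

def make_conns(self, query_res):
    # Work on a JAX array so the .at[...] updater is available,
    # and drop the trailing singleton dimension so the update
    # broadcasts against the (n,) column slice.
    conns = jnp.asarray(self.conns)
    query_res = jnp.asarray(query_res).reshape(-1)
    return conns.at[:, -1].set(query_res)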


CLL112 commented Feb 13, 2025


Thank you very much for your work. Additionally, I would like to ask: I am planning to use HyperNEAT for data prediction, where the input consists of 4 features, the output consists of 2 predicted values, and there are a total of 100,000 data samples. Would it be sufficient to modify the data code in examples/func_fit/xor_hyperneat.py as follows?

problem=custom(),


WLS2002 commented Feb 14, 2025

It seems like you want to use HyperNEAT to solve a problem that requires 4 inputs and 2 outputs. In that case, you also need to modify the Substrate settings in xor_hyperneat.py, as xor_hyperneat.py solves the xor3d problem, which has 3 inputs and 1 output.

You need to make your substrate have 5 input_coors (4 for the inputs plus 1 extra for the bias) and 2 output_coors.

You can use MLPSubstrate to create an MLP structure and automatically generate the coordinates:

from tensorneat.algorithm.hyperneat import MLPSubstrate

substrate=MLPSubstrate(
    layers=[5, 5, 5, 5, 2], coor_range=(-5.0, 5.0, -5.0, 5.0)
)

Or use FullSubstrate to configure coordinates by hand:

from tensorneat.algorithm.hyperneat import FullSubstrate

substrate=FullSubstrate(
    input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)),
    hidden_coors=((-1, 0), (0, 0), (1, 0)),
    output_coors=((-0.5, 1), (0.5, 1)),
)

P.S. You may need to adjust the concrete structure of the MLP or the coordinates to fit your problem.

Then you need to modify the problem, which you have already done:

problem=custom(),

Hope this can help you!


CLL112 commented Feb 14, 2025

Hi, thanks for the detailed reply. I have encountered two more issues:

  1. I have 100,000 data points, but I’ve found that using only 3,000 data points occupies about 18GB of VRAM. When I add more data, I get an out of memory error. Is it possible to divide the data into batches and perform multiple forward passes to compute the loss for all the data?

  2. The generated network is using too few nodes. I would like to use more neurons to achieve better results. I set:

MLPSubstrate(
    layers=[5, 10, 20, 15, 2], coor_range=(-5.0, 5.0, -5.0, 5.0)
),

However, after 100 generations, the maximum number of nodes used is only 18. How can I generate a more complex network with more nodes? Here's the output:

Generation: 100, Cost time: 2020.43ms
fitness: valid cnt: 10000, max: -0.0015, min: -1.1272, mean: -0.2691, std: 0.3698

node counts: max: 18, min: 7, mean: 10.87
conn counts: max: 17, min: 2, mean: 10.34
species: 19, [1124, 989, 605, 453, 854, 795, 737, 632, 645, 537, 490, 424, 362, 263, 177, 220, 164, 103, 426]


WLS2002 commented Feb 14, 2025

  1. Yes, you can divide the data into batches to reduce the memory cost. What problem are you working on now? Show me your code and I can help you implement this.
  2. "The maximum number of nodes used" refers to the nodes in NEAT, in other words, in the CPPN. HyperNEAT uses a small network (the CPPN) to generate the weights of a large network (the MLP that you designed, with structure [5, 10, 20, 15, 2]). So more nodes in the CPPN does not mean a more complex policy network. If you want to obtain a larger CPPN, you can modify the mutation configs or initialize NEAT with a larger structure:

from tensorneat.genome import DefaultGenome, DefaultMutation

neat = NEAT(
    pop_size=10000,
    species_size=20,
    survival_threshold=0.01,
    genome=DefaultGenome(
        num_inputs=4,
        num_outputs=1,
        init_hidden_layers=(10,),  # default is (), which means no hidden nodes
        mutation=DefaultMutation(
            node_add=0.2,  # default is 0.1
            conn_add=0.4,  # default is 0.2
        ),
    ),
)


CLL112 commented Feb 14, 2025

Thank you for your reply. Does the framework support batch inference? Currently, I manually split the data into multiple batches for inference. Is there a simpler way to do this?

import scipy.io
import jax.numpy as jnp
from jax import vmap

class custom(FuncFit):
    def __init__(self, filename=""):
        super().__init__()
        data = scipy.io.loadmat("data.mat")
        ...
        self.x = jnp.array(x)   # (100000, 4)
        self.y = jnp.array(y)   # (100000, 2)

    @property
    def inputs(self):
        return self.x

    @property
    def targets(self):
        return self.y

    @property
    def input_shape(self):
        return self.x.shape

    @property
    def output_shape(self):
        return self.y.shape  # the output shape should describe the targets, not the inputs

    def evaluate(self, state, randkey, act_func, params):
        batch_size = 1000
        num_batches = self.inputs.shape[0] // batch_size
        input_batches = jnp.array_split(self.inputs, num_batches, axis=0)

        result = []
        for batch in input_batches:
            predict = vmap(act_func, in_axes=(None, None, 0))(
                state, params, batch
            )
            result.append(predict)

        predict = jnp.concatenate(result, axis=0)
        loss = jnp.mean((predict - self.targets) ** 2)
        return -loss


WLS2002 commented Feb 14, 2025

It seems that you have already done this. TensorNEAT supports batch inference in both the population and datapoint dimensions. I think your code is simple enough for batch evaluation. Hope it works!

P.S. A Python for loop in JAX may cause a long compile time; you can use a JAX loop construct such as jax.lax.fori_loop to avoid the issue, as sketched below.
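For illustration, here is a minimal sketch of the evaluate method above rewritten with jax.lax.fori_loop (assuming the batch size divides the dataset evenly and that act_func returns one output vector per sample, as in your snippet):

import jax
import jax.numpy as jnp
from jax import vmap

def evaluate(self, state, randkey, act_func, params):
    batch_size = 1000
    num_batches = self.inputs.shape[0] // batch_size
    out_dim = self.targets.shape[-1]

    # Reshape so each loop iteration reads one fixed-size batch.
    inputs = self.inputs[: num_batches * batch_size].reshape(num_batches, batch_size, -1)
    targets = self.targets[: num_batches * batch_size]

    def body(i, preds):
        # Predict one batch and write it into the running result buffer.
        batch_pred = vmap(act_func, in_axes=(None, None, 0))(state, params, inputs[i])
        return preds.at[i].set(batch_pred)

    preds = jax.lax.fori_loop(0, num_batches, body, jnp.zeros((num_batches, batch_size, out_dim)))
    loss = jnp.mean((preds.reshape(-1, out_dim) - targets) ** 2)
    return -loss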


CLL112 commented Feb 14, 2025

Thank you for the reminder. After using jax.lax.fori_loop, the speed improved a lot. I still have a few questions. I used NEAT, HyperNEAT, and HyperNEATFeedForward, and found that NEAT performed the best and converged very quickly: the loss decreased from 0.01 to 0.0018. The performance of HyperNEAT and HyperNEATFeedForward was very limited: the loss decreased from 0.0586 to 0.0485 and then stopped. I thought HyperNEAT and HyperNEATFeedForward were superior to NEAT, so this is quite strange. Below is the code for the three algorithms. Could it be that I made an incorrect setting?
NEAT:

pipeline = Pipeline(
        algorithm=NEAT(
            pop_size=10000,
            species_size=20,
            survival_threshold=0.01,
            genome=DefaultGenome(
                num_inputs=4,
                num_outputs=2,
                init_hidden_layers=(),
                node_gene=BiasNode(
                    activation_options=[ACT.identity, ACT.inv, ACT.square, ACT.relu, ACT.lelu, ACT.square_root],
                    aggregation_options=[AGG.sum, AGG.product,AGG.mean],
                ),
                # output_transform=ACT.identity,
                output_transform=ACT.tanh,
            ),
        ),
        problem=custom(),
        generation_limit=100,
        fitness_target=-1e-6,
        seed=42,
    )

HyperNEAT:

pipeline = Pipeline(
        algorithm=HyperNEAT(
            substrate=FullSubstrate(
                # input_coors=((-1, -1), (-0.33, -1), (0.33, -1), (1, -1)),
                input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)),
                hidden_coors=((-1, 0), (0, 0), (1, 0)),
                output_coors=((-0.5, 1), (0.5, 1)),
            ),
            neat=NEAT(
                pop_size=10000,
                species_size=20,
                survival_threshold=0.01,
                genome=DefaultGenome(
                    num_inputs=4,  # size of query coors
                    num_outputs=1,
                    init_hidden_layers=(),
                    output_transform=ACT.tanh,
                    node_gene=BiasNode(
                        activation_options=[ACT.identity, ACT.inv, ACT.square, ACT.relu, ACT.lelu, ACT.square_root],
                        aggregation_options=[AGG.sum, AGG.product,AGG.mean],
                    ),
                ),
            ),
            activation=ACT.tanh,
            activate_time=10,
            output_transform=ACT.sigmoid,
        ),
        problem=custom(),
        generation_limit=100,
        fitness_target=-1e-6,
    )

HyperNEATFeedForward:

pipeline = Pipeline(
        algorithm=HyperNEATFeedForward(
            substrate=MLPSubstrate(
                layers=[5, 5, 10, 5, 2], coor_range=(-5.0, 5.0, -5.0, 5.0)
            ),
            neat=NEAT(
                pop_size=10000,
                species_size=20,
                survival_threshold=0.01,
                genome=DefaultGenome(
                    num_inputs=4,  # size of query coors
                    num_outputs=1,
                    init_hidden_layers=(10,),
                    output_transform=ACT.tanh,
                    node_gene=BiasNode(
                        activation_options=[ACT.identity, ACT.inv, ACT.square, ACT.relu, ACT.lelu],
                        aggregation_options=[AGG.sum, AGG.product, AGG.mean],
                    ),
                ),
            ),
            activation=ACT.tanh,
            output_transform=ACT.sigmoid,
        ),
        problem=custom(),
        generation_limit=100,
        fitness_target=-1e-5,
    )


WLS2002 commented Feb 15, 2025

The difference between NEAT and HyperNEAT is as follows:

  1. NEAT evolves small networks and directly uses them as policies.
    HyperNEAT evolves small networks, but these small networks generate the weights for larger networks, which are then used as policies.
  2. The small networks in NEAT are flexible—each node can have different activation functions and aggregation functions. In contrast, in HyperNEAT, all nodes in the large network have fixed activation and aggregation functions.

From my understanding, NEAT is more suitable for small-scale problems (in terms of observation space and action space), like your custom problem, as it can evolve flexible and diverse small networks. On the other hand, HyperNEAT is more suitable for large-scale problems, such as image classification. Therefore, it's hard to say that "HyperNEAT and HyperNEATFeedForward were superior to NEAT"—they have different characteristics and are suited to different types of problems.

Regarding the parameter settings you used, there might be a potential issue:
For NEAT, you used:

output_transform=ACT.tanh

While for HyperNEAT, you used:

output_transform=ACT.sigmoid, # output transform in HyperNEAT

In TensorNEAT, output_transform refers to the activation function used in the last layer of the network. This means that when using ACT.tanh, the network's output is constrained within [-1, 1], whereas using ACT.sigmoid constrains the output within [0, 1]. This difference might lead to significant performance discrepancies between the two algorithms on the same task.
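To make the range difference concrete, here is a tiny standalone check (plain JAX, independent of TensorNEAT) showing the two output ranges:

import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)
print(jnp.tanh(x))        # values lie in [-1, 1]
print(jax.nn.sigmoid(x))  # values lie in [0, 1]

If your targets contain negative values, a sigmoid-transformed output can never reach them, which by itself can put a floor under the loss.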

Hope this can help you!
