AttributeError: 'numpy.ndarray' object has no attribute 'at' #17
I changed the query code to use `jax.numpy`:

```python
import jax.numpy as jnp

# inside __init__
self.coors = jnp.array(coors)  # jnp instead of np
self.nodes = jnp.array(nodes)  # jnp instead of np
self.conns = jnp.array(conns)  # jnp instead of np

def make_conns(self, query_res):
    return self.conns.at[:, -1].set(query_res)
```

This still results in an error:
This is quite strange. I suspect this issue was caused by a JAX update. I have now fixed the problem. Thank you very much for using TensorNEAT and for reporting this issue! If you encounter any other problems, I'd be happy to help.
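For context, a minimal sketch of why the original error occurs: the functional `.at[...].set(...)` update exists only on JAX arrays, not on plain NumPy arrays, so any array that reaches `make_conns` as a `numpy.ndarray` raises the `AttributeError`:

```python
import numpy as np
import jax.numpy as jnp

conns = np.zeros((3, 4))
# conns.at[:, -1].set(1.0)          # AttributeError: numpy arrays have no 'at'

jconns = jnp.asarray(conns)          # convert to a JAX array first
jconns = jconns.at[:, -1].set(1.0)   # functional update; returns a new array
```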
Thank you very much for your work. Additionally, I would like to ask: I am planning to use HyperNEAT for data prediction, where the input consists of 4 features, the output consists of 2 predicted values, and there are a total of 100,000 data samples. Would it be sufficient to modify the data code in
It seems like you want to use HyperNEAT to solve a problem which requires 4 inputs and 2 outputs. In that case, you also need to modify the Substrate settings. Your substrate needs 5 input coordinates (4 + 1: 4 for the inputs plus an extra 1 for the bias) and 2 output coordinates. You can use an MLP-style substrate:

```python
from tensorneat.algorithm.hyperneat import MLPSubstrate

substrate = MLPSubstrate(
    layers=[5, 5, 5, 5, 2],
    coor_range=(-5.0, 5.0, -5.0, 5.0),
)
```

or a fully connected one:

```python
from tensorneat.algorithm.hyperneat import FullSubstrate

substrate = FullSubstrate(
    input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)),
    hidden_coors=((-1, 0), (0, 0), (1, 0)),
    output_coors=((-0.5, 1), (0.5, 1)),
)
```

P.S. You may need to adjust the concrete structure of the MLP or the coordinates to fit your problem. Then, you need to modify the problem, which you have already done: `problem=custom()`. Hope this can help you!
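For context, here is a hedged sketch of how such a substrate could plug into the HyperNEAT algorithm, modeled on the repository's xor_hyperneat example; the hyperparameter values below are placeholders, not recommendations, and the argument names should be checked against your installed version:

```python
from tensorneat.algorithm.neat import NEAT
from tensorneat.algorithm.hyperneat import HyperNEAT, FullSubstrate
from tensorneat.genome import DefaultGenome
from tensorneat.common import ACT

algorithm = HyperNEAT(
    substrate=FullSubstrate(
        input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)),  # 4 inputs + bias
        hidden_coors=((-1, 0), (0, 0), (1, 0)),
        output_coors=((-0.5, 1), (0.5, 1)),  # 2 outputs
    ),
    neat=NEAT(
        pop_size=1000,        # placeholder value
        species_size=20,      # placeholder value
        genome=DefaultGenome(
            num_inputs=4,     # the CPPN queries coordinate pairs: (x1, y1, x2, y2)
            num_outputs=1,    # one connection weight per query
        ),
    ),
    activation=ACT.tanh,
    output_transform=ACT.sigmoid,
)
```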
Hi, Thanks for the detailed reply. I have encountered two more issues:
However, after 100 generations, the maximum number of nodes used is only 18. How can I generate a more complex network with more nodes? Here's the output:
You can enlarge the population, initialize with hidden layers, and raise the structural mutation rates to encourage more complex networks:

```python
from tensorneat.algorithm.neat import NEAT  # assumed import path, matching tensorneat.algorithm.hyperneat above
from tensorneat.genome import DefaultGenome, DefaultMutation

neat = NEAT(
    pop_size=10000,
    species_size=20,
    survival_threshold=0.01,
    genome=DefaultGenome(
        num_inputs=4,
        num_outputs=1,
        init_hidden_layers=(10,),  # default is (), which means no hidden nodes
        mutation=DefaultMutation(
            node_add=0.2,  # default is 0.1
            conn_add=0.4,  # default is 0.2
        ),
    ),
)
```
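For completeness, a hedged sketch of running such a configuration through TensorNEAT's Pipeline, following the style of the repository's examples; the generation limit and fitness target are placeholder values:

```python
from tensorneat.pipeline import Pipeline

pipeline = Pipeline(
    algorithm=neat,
    problem=custom(),     # the custom problem mentioned earlier in the thread
    generation_limit=100,
    fitness_target=1.0,   # placeholder; set to match your fitness scale
)
state = pipeline.setup()
state, best = pipeline.auto_run(state)
```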
Thank you for your reply. Does the framework support batch inference? Currently, I manually split the data into multiple batches for inference. Is there a simpler way to do this?
It seems that you have already done this. TensorNEAT supports batch inference in both the population and datapoint dimensions. I think your code is simple enough for batch evaluation. Hope it works! P.S. A Python for-loop under JAX can cause long compile times; you can use JAX's built-in loop constructs instead (see the sketch below).
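As an illustration of that P.S., here is a minimal sketch (not TensorNEAT's API) of replacing a Python for-loop over batches with `jax.lax.scan`; the `forward` function and its parameters are hypothetical stand-ins for your network:

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    # hypothetical per-sample network: 4 features in, 2 predictions out
    return jnp.tanh(x @ params)

def run_batches(params, data, batch_size=1000):
    # assumes len(data) is divisible by batch_size
    batches = data.reshape(-1, batch_size, data.shape[-1])

    def step(carry, batch):
        out = jax.vmap(forward, in_axes=(None, 0))(params, batch)
        return carry, out

    # scan traces the loop body once, instead of unrolling a Python for-loop
    _, outs = jax.lax.scan(step, None, batches)
    return outs.reshape(-1, outs.shape[-1])

params = jnp.zeros((4, 2))
data = jnp.ones((100_000, 4))
preds = run_batches(params, data)  # shape (100000, 2)
```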
Thank you for your reminder. After using
HyperNEAT:
HyperNEATFeedForward:
The difference between NEAT and HyperNEAT is as follows: from my understanding, NEAT is more suitable for small-scale problems (in terms of observation space and action space), like your custom problem, as it can evolve flexible and diverse small networks. On the other hand, HyperNEAT is more suitable for large-scale problems, such as image classification. Therefore, it's hard to say that "HyperNEAT and HyperNEATFeedForward were superior to NEAT"; they have different characteristics and are suited to different types of problems.

Regarding the parameter settings you used, there might be a potential issue. For NEAT you used:

```python
output_transform=ACT.tanh
```

while for HyperNEAT you used:

```python
output_transform=ACT.sigmoid,  # output transform in HyperNEAT
```

In TensorNEAT, `ACT.tanh` maps outputs to (-1, 1) while `ACT.sigmoid` maps them to (0, 1), so runs that use different output transforms are hard to compare directly. Hope this can help you!
I installed Python 3.10 and followed the tutorial for installation. When running `examples/func_fit/xor_hyperneat.py`, I encountered the following error. It seems like the issue is related to the NumPy version, but I have tried both 1.26.2 and 2.0.1 and still get the error. Below are the versions of some of the packages in my environment: