Unable to execute simple finetune examples due to jnp.linalg.norm error #4

snerligit opened this issue Feb 16, 2023 · 2 comments


@snerligit

Please see the error I am facing while running the simple finetune command given in the README:

```
done importing
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
binder_intercepts: [0.80367635, 0.43373787]
cmd: /gstore/scratch/u/nerlis/alphafold_finetune/run_finetuning.py --data_dir /gstore/scratch/u/nerlis/alphafold_finetune/alphafold/data --binder_intercepts 0.80367635 --binder_intercepts 0.43373787 --freeze_binder --train_dataset /gstore/scratch/u/nerlis/alphafold_finetune/examples/tiny_pmhc_finetune/tiny_example_train.tsv --valid_dataset /gstore/scratch/u/nerlis/alphafold_finetune/examples/tiny_pmhc_finetune/tiny_example_valid.tsv
local_device: cpu ng033
model_name: model_2_ptm
outprefix: testrun
WARNING:tensorflow:From /gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. tensorflow/tensorflow#56089
pae: [[0.]]
initial binder params: {'PaeBinderClassifier': {'x_intercept': DeviceArray([[0.80367637, 0.43373787]], dtype=float32), 'slope': DeviceArray([-7.901963], dtype=float32)}}
create_batch_from_dataset_row: DRB1_0901_True_SVAYKAAVGATPEAK_2 unk
WARNING chainbreak: B 29 31 4.174218489729545 examples/tiny_pmhc_finetune/templates/6qzc_MH2_DRA_01010101_DRB1_01010101.pdb
WARNING chainbreak: A 74 275 11.303902954289725 examples/tiny_pmhc_finetune/natives/run135_batch_0660DRB1_0901_True_SVAYKAAVGATPEAK_2_model_1_model_1.pdb
WARNING chainbreak: A 359 560 17.32551615392742 examples/tiny_pmhc_finetune/natives/run135_batch_0660DRB1_0901_True_SVAYKAAVGATPEAK_2_model_1_model_1.pdb
train_epoch: 0 batch: 0
binder_params: {'PaeBinderClassifier': {'slope': DeviceArray([-7.901963], dtype=float32), 'x_intercept': DeviceArray([[0.80367637, 0.43373787]], dtype=float32)}}
not setting num_iter_recycling!!! will do 3 recycles
2023-02-15 16:04:11.006011: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65]
[Compiling module pmap_train_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2023-02-15 16:09:49.975817: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 7m38.970211666s
[Compiling module pmap_train_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
pae: Traced<ShapedArray(float32[1])>with<JVPTrace(level=2/1)> with
primal = Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=0/1)>
tangent = Traced<ShapedArray(float32[1])>with<JaxprTrace(level=1/1)> with
pval = (ShapedArray(float32[1]), None)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x2aacf2d691e0>, in_tracers=(Traced<ShapedArray(float32[]):JaxprTrace(level=1/1)>, Traced<ShapedArray(float32[1]):JaxprTrace(level=1/1)>), out_tracer_refs=[<weakref at 0x2aad0069b540; to 'JaxprTracer' at 0x2aad0069bd60>], out_avals=[ShapedArray(float32[1])], primitive=xla_call, params={'device': None, 'backend': None, 'name': 'true_divide', 'donated_invars': (False, False), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[] b:f32[1]. let c:f32[1] = div b a in (c,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x2aad04505d30>, name_stack=NameStack(stack=(Transform(name='jvp'), Scope(name='PaeBinderClassifier')))))
train_epoch_n= 0 0 loss= 0.84698945 structure_flag: False lddt_ca= 1.0 fape= 0.04288465 binder_probs= [[0.4951988 0.5048012]] binder_loss= [0.68359053] peptide_plddt= 108.62619 binder_features= [[0.4313074]] binder_labels= [array([[0., 1.]])] binder_params= {'PaeBinderClassifier': {'slope': DeviceArray([-7.901963], dtype=float32), 'x_intercept': DeviceArray([[0.80367637, 0.43373787]], dtype=float32)}}
grad accumulate: 1 0
grad update! 1 1
Traceback (most recent call last):
File "/gstore/scratch/u/nerlis/alphafold_finetune/run_finetuning.py", line 786, in
grads_sum = norm_grads_per_example(grads_sum,
File "/gstore/scratch/u/nerlis/alphafold_finetune/run_finetuning.py", line 380, in norm_grads_per_example
total_grad_norm = jnp.linalg.norm([jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/api.py", line 622, in cache_miss
execute = dispatch.xla_call_impl_lazy(fun, *tracers, **params)
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/linear_util.py", line 303, in memoized_fun
ans = call(fun, *args)
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 490, in norm
_check_arraylike("jnp.linalg.norm", x)
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 345, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/gstore/scratch/u/nerlis/alphafold_finetune/run_finetuning.py", line 786, in
grads_sum = norm_grads_per_example(grads_sum,
File "/gstore/scratch/u/nerlis/alphafold_finetune/run_finetuning.py", line 380, in norm_grads_per_example
total_grad_norm = jnp.linalg.norm([jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 490, in norm
_check_arraylike("jnp.linalg.norm", x)
File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 345, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.
```
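
The failing call at run_finetuning.py line 380 passes a Python list to jnp.linalg.norm; newer JAX releases require an ndarray or scalar there, which is exactly the TypeError reported above. Below is a minimal sketch of a possible workaround, assuming the intent is the global L2 norm over all gradient leaves (the arrays in nonempty_grads are toy stand-ins, not the real gradients from the training loop):

```python
import jax.numpy as jnp

# Toy stand-in for the per-parameter gradient arrays seen in the traceback.
nonempty_grads = [jnp.ones((3, 4)), jnp.zeros((5,)), jnp.full((2, 2), 2.0)]

# Original pattern from norm_grads_per_example (fails on newer JAX, which
# rejects plain Python lists):
#   total_grad_norm = jnp.linalg.norm(
#       [jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])

# Possible fix: stack the per-leaf norms into an array before taking the norm.
per_leaf_norms = jnp.stack(
    [jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
total_grad_norm = jnp.linalg.norm(per_leaf_norms)
print(total_grad_norm)  # global L2 norm across all gradient leaves
```

Using jnp.asarray on the list of per-leaf norms would work equally well. Pinning JAX/jaxlib to an older release that still accepted lists may also sidestep the error, though the exact version expected by the repo's examples would need to be confirmed from its install notes.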

@ljq0811

ljq0811 commented Sep 14, 2023

I have the same problem. Have you managed to solve it?

@phbradley (Owner)

Hi there,
I'm sorry for missing this issue originally, and I appreciate that installing the various dependencies is challenging. Two things might help:

Thanks,
Phil
