
How did you deal with DWT shape differences? #3

Open
zaptrem opened this issue Aug 19, 2024 · 4 comments

Comments


zaptrem commented Aug 19, 2024

[screenshot attached]

Each band that comes out of DWT has a different length. How did you fit them all into the model where all inputs have to be the same length? Can you explain the actual shape changes a waveform goes through as it makes its way to and from your model? e.g., Wave B, T -> ??? -> Wave B, T? The diagram in the paper is unclear since DWT spits out lots of different sequence lengths.

sh-lee-prml (Owner) commented Aug 20, 2024

Thanks for your interest.

[B,1,T] --> DWT --> [B,4,T//4]

We use the second dimension to index each target DWT component:

[B, DWT_i:DWT_i+1, T//4] is x1 for each model.

For progressive generation, which generates the lower bands first,

we can condition on the DWT components of the previous bands together, [B, :DWT_i, T//4].

So the input to the network during training is

Concat(x1 of [B,:DWT_i,T//4], xt of [B,DWT_i:DWT_i+1,T//4], dim=1)

The output of the network is the vector field of xt for [B, DWT_i:DWT_i+1, T//4].

At inference, the ground-truth x1 of [B, :DWT_i, T//4] is replaced with the previously generated DWT components.

For waveform reconstruction,

[B,4,T//4] --> iDWT --> [B,1,T]

Thanks!
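The shape flow described above can be sketched with plain tensors. This is a minimal sketch, not the author's code: the random tensors stand in for real DWT outputs (shapes only), and the linear flow-matching interpolant for xt is an assumption about the noising scheme, not something stated in this thread.

```python
import torch

B, T = 2, 1024
x1 = torch.randn(B, 4, T // 4)      # full DWT stack: [B, 4, T//4]

DWT_i = 2                           # band currently being generated
cond = x1[:, :DWT_i, :]             # previously generated bands: [B, DWT_i, T//4]
target = x1[:, DWT_i:DWT_i + 1, :]  # current target band x1:     [B, 1, T//4]

t = torch.rand(B, 1, 1)
noise = torch.randn_like(target)
xt = (1 - t) * noise + t * target   # assumed linear interpolant between noise and x1

# Network input: conditioning bands concatenated with the noisy target band.
net_in = torch.cat([cond, xt], dim=1)
print(net_in.shape)                 # torch.Size([2, 3, 256]) = [B, DWT_i + 1, T//4]
```

The network then predicts the vector field for the [B, 1, T//4] target slice, and at inference the conditioning slice is filled with previously generated bands instead of ground truth.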

zaptrem (Author) commented Aug 20, 2024

Thanks for the prompt response! Doesn't DWT output something like the shapes below (since by definition it trades off time and frequency resolution)? T is different for each band. The example below uses pytorch_wavelets.

import torch
from pytorch_wavelets import DWT1DForward, DWT1DInverse  # or simply DWT1D, IDWT1D
dwt = DWT1DForward(wave='db6', J=3)
X = torch.randn(10, 5, 100)
yl, yh = dwt(X)
print(yl.shape)     # torch.Size([10, 5, 22])
print(yh[0].shape)  # torch.Size([10, 5, 55])
print(yh[1].shape)  # torch.Size([10, 5, 33])
print(yh[2].shape)  # torch.Size([10, 5, 22])
idwt = DWT1DInverse(wave='db6')
x = idwt((yl, yh))

Which DWT library are you using?

sh-lee-prml (Owner)

import torch
from pytorch_wavelets import DWT1DForward, DWT1DInverse  # or simply DWT1D, IDWT1D

y = torch.randn(10, 1, 128)

dwt = DWT1DForward()    # defaults: wave='db1', J=1
idwt = DWT1DInverse()

# First level: split into low- and high-frequency halves of length T//2.
x_dwt1, x_dwt2 = dwt(y)              # x_dwt1: [10, 1, 64], x_dwt2[0]: [10, 1, 64]

# Second level: split each half again, giving four bands of length T//4.
x_dwt1_a, x_dwt1_b = dwt(x_dwt1)     # [10, 1, 32] each
x_dwt2_a, x_dwt2_b = dwt(x_dwt2[0])  # [10, 1, 32] each

print(x_dwt1_a.shape)
print(x_dwt1_b[0].shape)
print(x_dwt2_a.shape)
print(x_dwt2_b[0].shape)

# Stack the four equal-length bands along the channel dimension: [B, 4, T//4]
x1 = torch.concat([x_dwt1_a, x_dwt1_b[0], x_dwt2_a, x_dwt2_b[0]], dim=1)

print(x1.shape)  # torch.Size([10, 4, 32])

# Invert level by level to reconstruct the waveform: [B, 1, T]
x_low = idwt([x_dwt1_a, [x_dwt1_b[0]]])
x_high = idwt([x_dwt2_a, [x_dwt2_b[0]]])
x = idwt([x_low, [x_high]])

print(x.shape)  # torch.Size([10, 1, 128])

Try this. T should be divisible by 4 (T % 4 == 0) so that all four bands have the same length; if it isn't, you can use zero-padding.

@zaptrem
Copy link
Author

zaptrem commented Aug 20, 2024

Ah, that makes sense. Thanks!
