Skip to content

Commit a3e8d3f

Browse files
Jwaminjua-r-r-o-w
andauthored
[docs] refactoring docstrings in models/embeddings_flax.py (#9592)
* [docs] refactoring docstrings in `models/embeddings_flax.py` * Update src/diffusers/models/embeddings_flax.py * make style --------- Co-authored-by: Aryan <[email protected]>
1 parent fff4be8 commit a3e8d3f

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

src/diffusers/models/embeddings_flax.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
2929
"""Returns the positional encoding (same as Tensor2Tensor).
3030
3131
Args:
32-
timesteps: a 1-D Tensor of N indices, one per batch element.
33-
These may be fractional.
34-
embedding_dim: The number of output channels.
35-
min_timescale: The smallest time unit (should probably be 0.0).
36-
max_timescale: The largest time unit.
32+
timesteps (`jnp.ndarray` of shape `(N,)`):
33+
A 1-D array of N indices, one per batch element. These may be fractional.
34+
embedding_dim (`int`):
35+
The number of output channels.
36+
freq_shift (`float`, *optional*, defaults to `1`):
37+
Shift applied to the frequency scaling of the embeddings.
38+
min_timescale (`float`, *optional*, defaults to `1`):
39+
The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
40+
max_timescale (`float`, *optional*, defaults to `1.0e4`):
41+
The largest time unit used in the sinusoidal calculation.
42+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
43+
Whether to flip the order of sinusoidal components to cosine first.
44+
scale (`float`, *optional*, defaults to `1.0`):
45+
A scaling factor applied to the positional embeddings.
46+
3747
Returns:
3848
a Tensor of timing signals [N, num_channels]
3949
"""
@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
6171
6272
Args:
6373
time_embed_dim (`int`, *optional*, defaults to `32`):
64-
Time step embedding dimension
65-
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
66-
Parameters `dtype`
74+
Time step embedding dimension.
75+
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
76+
The data type for the embedding parameters.
6777
"""
6878

6979
time_embed_dim: int = 32
@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
8393
8494
Args:
8595
dim (`int`, *optional*, defaults to `32`):
86-
Time step embedding dimension
96+
Time step embedding dimension.
97+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
98+
Whether to flip the sinusoidal function from sine to cosine.
99+
freq_shift (`float`, *optional*, defaults to `1`):
100+
Frequency shift applied to the sinusoidal embeddings.
87101
"""
88102

89103
dim: int = 32

0 commit comments

Comments
 (0)