@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
29
29
"""Returns the positional encoding (same as Tensor2Tensor).
30
30
31
31
Args:
32
- timesteps: a 1-D Tensor of N indices, one per batch element.
33
- These may be fractional.
34
- embedding_dim: The number of output channels.
35
- min_timescale: The smallest time unit (should probably be 0.0).
36
- max_timescale: The largest time unit.
32
+ timesteps (`jnp.ndarray` of shape `(N,)`):
33
+ A 1-D array of N indices, one per batch element. These may be fractional.
34
+ embedding_dim (`int`):
35
+ The number of output channels.
36
+ freq_shift (`float`, *optional*, defaults to `1`):
37
+ Shift applied to the frequency scaling of the embeddings.
38
+ min_timescale (`float`, *optional*, defaults to `1`):
39
+ The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
40
+ max_timescale (`float`, *optional*, defaults to `1.0e4`):
41
+ The largest time unit used in the sinusoidal calculation.
42
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
43
+ Whether to flip the order of sinusoidal components to cosine first.
44
+ scale (`float`, *optional*, defaults to `1.0`):
45
+ A scaling factor applied to the positional embeddings.
46
+
37
47
Returns:
38
48
a Tensor of timing signals [N, num_channels]
39
49
"""
@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
61
71
62
72
Args:
63
73
time_embed_dim (`int`, *optional*, defaults to `32`):
64
- Time step embedding dimension
65
- dtype (:obj: `jnp.dtype`, *optional*, defaults to jnp.float32):
66
- Parameters `dtype`
74
+ Time step embedding dimension.
75
+ dtype (`jnp.dtype`, *optional*, defaults to ` jnp.float32` ):
76
+ The data type for the embedding parameters.
67
77
"""
68
78
69
79
time_embed_dim : int = 32
@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
83
93
84
94
Args:
85
95
dim (`int`, *optional*, defaults to `32`):
86
- Time step embedding dimension
96
+ Time step embedding dimension.
97
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
98
+ Whether to flip the sinusoidal function from sine to cosine.
99
+ freq_shift (`float`, *optional*, defaults to `1`):
100
+ Frequency shift applied to the sinusoidal embeddings.
87
101
"""
88
102
89
103
dim : int = 32
0 commit comments