Skip to content

Commit 3892635

Browse files
[Term Entry] PyTorch Tensor Operations: .take_along_dim()
* Add .take_along_dim() documentation entry under PyTorch * Enhance documentation for .take_along_dim() in PyTorch with detailed descriptions, examples, and key features * Fix formatting inconsistencies in take_along_dim.md examples * Improve documentation for .take_along_dim() with enhanced examples and other content * Refine documentation for .take_along_dim() by clarifying parameters, returns, and enhancing examples based on the comments from dakshdeepHERE * Improve formatting of parameters in .take_along_dim() documentation for clarity * Update take-along-dim.md * Minor changes ---------
1 parent 7d4103e commit 3892635

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
---
2+
Title: '.take_along_dim()'
3+
Description: 'Select elements from a tensor along a specified dimension using indices.'
4+
Subjects:
5+
- 'Data Science'
6+
- 'AI'
7+
Tags:
8+
- 'Tensor'
9+
- 'PyTorch'
10+
- 'Deep Learning'
11+
- 'Neural Networks'
12+
CatalogContent:
13+
- 'intro-to-py-torch-and-neural-networks'
14+
- 'paths/machine-learning'
15+
---
16+
17+
The **`.take_along_dim()`** function in PyTorch is used to select elements from a [tensor](https://www.codecademy.com/resources/docs/pytorch/tensors) along a specified dimension. This operation is essential for advanced indexing operations and manipulating multi-dimensional tensors in deep learning applications.
18+
19+
Similar to [`.take()`](https://www.codecademy.com/resources/docs/pytorch/tensor-operations/take), which extracts elements based on indices and always returns a 1D tensor, `.take_along_dim()` provides a more flexible approach by allowing indexing along a specific dimension while preserving the tensor's shape.
20+
21+
## Syntax
22+
23+
```pseudo
24+
torch.take_along_dim(input, indices, dim)
25+
```
26+
27+
- `input`: The source tensor from which elements will be selected.
28+
- `indices`: A tensor of indices specifying which elements to select along the specified dimension.
29+
- `dim`: The dimension along which to perform the selection.
30+
31+
It returns a new tensor with the same dimensionality as the input tensor, containing the selected elements.
32+
33+
## Example
34+
35+
Here is a basic usage example of `.take_along_dim()` in PyTorch to select elements along a specific dimension:
36+
37+
```py
38+
import torch
39+
40+
# Create a source tensor
41+
input_tensor = torch.tensor([[10, 20, 30],
42+
[40, 50, 60]])
43+
44+
# Define indices for selection
45+
indices = torch.tensor([[2, 1, 0],
46+
[1, 0, 2]])
47+
48+
# Select elements along dimension 1
49+
result = torch.take_along_dim(input_tensor, indices, dim=1)
50+
51+
# Print the result
52+
print(result)
53+
```
54+
55+
The following will be the output of the above code:
56+
57+
```shell
58+
tensor([[30, 20, 10],
59+
[50, 40, 60]])
60+
```
61+
62+
Moreover, the function can also be used to select elements along a specific dimension in a multi-dimensional tensor. For instance, the following example can be considered:
63+
64+
```py
65+
import torch
66+
67+
# Create a 3D tensor
68+
input_tensor = torch.tensor([[[1, 2], [3, 4]],
69+
[[5, 6], [7, 8]]])
70+
71+
# Define indices for selection
72+
indices = torch.tensor([[[0, 1], [1, 0]],
73+
[[0, 0], [1, 1]]])
74+
75+
# Select elements along the last dimension
76+
result = torch.take_along_dim(input_tensor, indices, dim=2)
77+
78+
# Print the result
79+
print(result)
80+
```
81+
82+
The output of the above code will be:
83+
84+
```shell
85+
tensor([[[1, 2],
86+
[4, 3]],
87+
88+
[[5, 5],
89+
[8, 8]]])
90+
```
91+
92+
## Key Features
93+
94+
Here are some key features of the `.take_along_dim()` function:
95+
96+
- Preserves tensor dimensionality during selection
97+
- Supports batch operations
98+
- Works with any number of dimensions
99+
- Maintains gradient information for backpropagation
100+
101+
## Common Use Cases
102+
103+
Here are some common use cases of the `.take_along_dim()` function:
104+
105+
- Sorting tensor elements
106+
- Implementing attention mechanisms
107+
- Selecting top-k elements
108+
- Custom pooling operations
109+
110+
## Notes
111+
112+
Here are some notes about the `.take_along_dim()` function:
113+
114+
- The indices tensor must have the same shape as the input tensor
115+
- Supports automatic differentiation
116+
- More flexible than the `.take()` function for multi-dimensional operations
117+
- Memory-efficient for large tensor operations

0 commit comments

Comments
 (0)