Skip to content

[Term Entry]PyTorch Split_Tensor .tensor_split() - Requested Edits #5952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 18, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
---
Title: '.tensor_split()'
Description: 'Splits a tensor into multiple sub-tensors along a specified dimension, based on either specified indices or the number of equal parts.'
Subjects:
- 'AI'
- 'Data Science'
Tags:
- 'Data Structures'
- 'Deep Learning'
- 'PyTorch'
- 'Tensor'
CatalogContent:
- 'intro-to-py-torch-and-neural-networks'
- 'paths/data-science'
---

In PyTorch, the **`.tensor_split()`** function splits a tensor into multiple sub-tensors along a specified dimension. If the tensor cannot be split evenly, the function distributes the elements across the sub-tensors as evenly as possible.

## Syntax

```pseudo
torch.tensor_split(input, indices_or_sections, dim=0)
```

- `input`: The tensor to be split.
- `indices_or_sections`:
- If _int_: The number of sub-tensors to split the input tensor into. If the split is uneven, the resulting sub-tensors will differ in size to distribute elements as evenly as possible.
- If _list or tuple of ints_: The indices at which to split the tensor along the specified dimension.
- `dim`: The dimension along which to split the tensor. Default is `0`.

## Example

The following example demonstrates the use of the `.tensor_split()` function:

```py
import torch

# Create a one-dimensional tensor
x = torch.arange(10)

# Split the tensor into 2 parts
result = torch.tensor_split(x, 2)

# Print the result
print(result)
```

The code above gives the output as follows:

```shell
(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
```

The output is a list of two sub-tensors, where the input tensor is evenly split into two parts along its only dimension.
Loading