17 changes: 17 additions & 0 deletions include/infinicore/ops/avg_pool3d.hpp
@@ -0,0 +1,17 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"
#include <tuple>

namespace infinicore::op {
class AvgPool3d {
public:
    using schema = void (*)(Tensor, Tensor, std::tuple<size_t, size_t, size_t>, std::tuple<size_t, size_t, size_t>, std::tuple<size_t, size_t, size_t>, bool);
    static void execute(Tensor output, Tensor input, std::tuple<size_t, size_t, size_t> kernel_size, std::tuple<size_t, size_t, size_t> stride, std::tuple<size_t, size_t, size_t> padding, bool ceil_mode);
    static common::OpDispatcher<schema> &dispatcher();
};

Tensor avg_pool3d(Tensor input, std::tuple<size_t, size_t, size_t> kernel_size, std::tuple<size_t, size_t, size_t> stride, std::tuple<size_t, size_t, size_t> padding, bool ceil_mode);
void avg_pool3d_(Tensor output, Tensor input, std::tuple<size_t, size_t, size_t> kernel_size, std::tuple<size_t, size_t, size_t> stride, std::tuple<size_t, size_t, size_t> padding, bool ceil_mode);
} // namespace infinicore::op
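The header above pairs an out-of-place `avg_pool3d` (which allocates its result) with an in-place `avg_pool3d_` (which writes into a caller-provided output), both routed through `AvgPool3d::execute` and its per-device dispatcher. A minimal sketch of how the operator is reached through the Python functional API added later in this PR; the input tensor `x` is hypothetical and assumed to already exist as a 5-D infinicore tensor:

```python
import infinicore.nn.functional as F

# `x` is assumed to be an existing infinicore Tensor of shape (N, C, D, H, W).
y = F.avg_pool3d(x, kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0), ceil_mode=False)

# With ceil_mode=False each spatial extent shrinks to
# floor((size + 2 * padding - kernel) / stride) + 1, matching the shape
# computation in src/infinicore/ops/avg_pool3d/avg_pool3d.cc below.
```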
16 changes: 16 additions & 0 deletions include/infinicore/ops/dot.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Dot {
public:
    using schema = void (*)(Tensor, Tensor, Tensor);
    static void execute(Tensor c, Tensor a, Tensor b);
    static common::OpDispatcher<schema> &dispatcher();
};

Tensor dot(Tensor a, Tensor b);
void dot_(Tensor c, Tensor a, Tensor b);
} // namespace infinicore::op
16 changes: 16 additions & 0 deletions include/infinicore/ops/histc.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Histc {
public:
    using schema = void (*)(Tensor, Tensor, size_t, double, double);
    static void execute(Tensor input, Tensor output, size_t bins, double min, double max);
    static common::OpDispatcher<schema> &dispatcher();
};

Tensor histc(Tensor input, size_t bins, double min, double max);
void histc_(Tensor input, Tensor output, size_t bins, double min, double max);
} // namespace infinicore::op
16 changes: 16 additions & 0 deletions include/infinicore/ops/log10.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Log10 {
public:
    using schema = void (*)(Tensor, Tensor);
    static void execute(Tensor output, Tensor input);
    static common::OpDispatcher<schema> &dispatcher();
};

Tensor log10(Tensor input);
void log10_(Tensor output, Tensor input);
} // namespace infinicore::op
16 changes: 16 additions & 0 deletions include/infinicore/ops/log1p.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Log1p {
public:
    using schema = void (*)(Tensor, Tensor);
    static void execute(Tensor output, Tensor input);
    static common::OpDispatcher<schema> &dispatcher();
};

Tensor log1p(Tensor input);
void log1p_(Tensor output, Tensor input);
} // namespace infinicore::op
15 changes: 15 additions & 0 deletions include/infinicore/ops/zeros.hpp
@@ -0,0 +1,15 @@
#pragma once

#include "common/op.hpp"

namespace infinicore::op {
class Zeros {

public:
    using schema = void (*)(Tensor);
    static void execute(Tensor output);
    static common::OpDispatcher<schema> &dispatcher();
};

void zeros_(Tensor output);
} // namespace infinicore::op
8 changes: 8 additions & 0 deletions python/infinicore/__init__.py
@@ -45,6 +45,10 @@
from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow
from infinicore.ops.rearrange import rearrange
from infinicore.ops.dot import dot
from infinicore.ops.histc import histc
from infinicore.ops.log10 import log10
from infinicore.ops.log1p import log1p
from infinicore.tensor import (
    Tensor,
    empty,
@@ -115,6 +119,10 @@
"strided_empty",
"strided_from_blob",
"zeros",
"dot",
"log10",
"log1p",
"histc",
]

use_ntops = False
3 changes: 2 additions & 1 deletion python/infinicore/nn/__init__.py
@@ -1,5 +1,6 @@
from infinicore.nn import functional
from infinicore.nn.modules import * # noqa: F403
from infinicore.nn.parameter import InfiniCoreParameter as Parameter
from infinicore.nn import init

__all__ = ["functional", "Parameter"]
__all__ = ["functional", "Parameter", "init"]
2 changes: 2 additions & 0 deletions python/infinicore/nn/functional/__init__.py
@@ -6,6 +6,7 @@
from .rope import RopeAlgo, rope
from .silu import silu
from .swiglu import swiglu
from .avg_pool3d import avg_pool3d

__all__ = [
"causal_softmax",
@@ -17,4 +18,5 @@
"embedding",
"rope",
"RopeAlgo",
"avg_pool3d",
]
68 changes: 68 additions & 0 deletions python/infinicore/nn/functional/avg_pool3d.py
@@ -0,0 +1,68 @@
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def _zeros_pad(input: Tensor, padding: tuple[int, ...]) -> Tensor:
    r"""Zero-pad a tensor.

    Args:
        input (Tensor): The input tensor.
        padding (tuple[int, ...]): The per-dimension padding sizes.

    Returns:
        Tensor: The padded tensor.
    """
    output_shape = []
    for i in range(input.ndim):
        output_shape.append(input.size(i) + 2 * padding[i])

    output = infinicore.empty(output_shape, dtype=input.dtype, device=input.device)
    output = infinicore.nn.init.zeros_(output)

    # Use narrow to locate the target region, then copy the data into it.
    # narrow has to be applied one dimension at a time.
    output_view = output
    for dim in range(len(input.size())):
        output_view = infinicore.narrow(output_view, dim, padding[dim], input.size(dim))

    # Copy the input data into the corresponding region of the output tensor.
    infinicore.add(input, output_view, out=output_view)

    return output


def avg_pool3d(
    input: Tensor,
    kernel_size: tuple[int, int, int] | int,
    stride: tuple[int, int, int] | int | None = None,
    padding: tuple[int, int, int] | int = 0,
    ceil_mode: bool = False,
):
    r"""Applies a 3D average pooling over an input signal composed of several input
    planes."""
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)

    if isinstance(stride, int):
        stride = (stride, stride, stride)

    if stride is None:
        stride = kernel_size

    if isinstance(padding, int):
        padding = [padding, padding, padding]

    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
        padding = [0, 0] + list(padding)

        if any(p > 0 for p in padding):
            input = _zeros_pad(input, padding)
        return infinicore.ntops.torch.avg_pool3d(input, kernel_size, stride, ceil_mode)

    # Fall back to the native InfiniCore implementation.
    return Tensor(
        _infinicore.avg_pool3d(
            input._underlying, kernel_size, stride, padding, ceil_mode
        )
    )
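The wrapper normalizes scalar arguments to 3-tuples, and on the ntops path it widens the 3-element spatial padding to a per-dimension list (`[0, 0, pD, pH, pW]`) before `_zeros_pad` is applied. A short sketch of the scalar/tuple equivalence, assuming `x` is an existing 5-D infinicore tensor (hypothetical):

```python
import infinicore.nn.functional as F

# `x` is assumed to be an existing infinicore Tensor of shape (N, C, D, H, W).
a = F.avg_pool3d(x, kernel_size=2)          # scalars broadcast to (2, 2, 2)
b = F.avg_pool3d(x, kernel_size=(2, 2, 2))  # identical call with explicit tuples
# When stride is omitted it defaults to kernel_size, as in the wrapper above.
```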
5 changes: 5 additions & 0 deletions python/infinicore/nn/init/__init__.py
@@ -0,0 +1,5 @@
from .zeros_ import zeros_

__all__ = [
"zeros_",
]
9 changes: 9 additions & 0 deletions python/infinicore/nn/init/zeros_.py
@@ -0,0 +1,9 @@
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def zeros_(input: Tensor) -> Tensor:
r"""Fill the input tensor with the scalar value 0."""
_infinicore.zeros_(input._underlying)
return input
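`zeros_` mirrors the in-place initializers of `torch.nn.init`: it fills an existing tensor and returns it, which is exactly how `_zeros_pad` above uses it. A minimal sketch, assuming `buf` is an already-allocated infinicore tensor (hypothetical):

```python
import infinicore

# `buf` is assumed to be an existing infinicore Tensor (any shape, any device).
buf = infinicore.nn.init.zeros_(buf)  # fills in place and returns the same tensor
```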
16 changes: 16 additions & 0 deletions python/infinicore/ops/dot.py
@@ -0,0 +1,16 @@
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def dot(input: Tensor, tensor: Tensor, *, out=None) -> Tensor:
r"""Compute the dot product of two 1-D tensors."""

if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
return infinicore.ntops.torch.dot(input, tensor, out=out)

if out is None:
return Tensor(_infinicore.dot(input._underlying, tensor._underlying))

_infinicore.dot_(out._underlying, input._underlying, tensor._underlying)
return out
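Usage follows the out-of-place / preallocated-`out` pattern shared by the other wrappers in this PR. A sketch, assuming `a` and `b` are existing 1-D infinicore tensors of equal length (hypothetical):

```python
import infinicore

# `a` and `b` are assumed to be existing 1-D infinicore Tensors of equal length.
c = infinicore.dot(a, b)  # out-of-place: allocates and returns the result
# A preallocated result tensor can be supplied instead via dot(a, b, out=...).
```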
13 changes: 13 additions & 0 deletions python/infinicore/ops/histc.py
@@ -0,0 +1,13 @@
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def histc(input: Tensor, bins: int = 100, min: float | None = None, max: float | None = None) -> Tensor:
    r"""Compute the histogram of a tensor with `bins` equal-width bins between `min` and `max`."""

    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
        is_moore = input._underlying.device.type == _infinicore.Device.Type.MOORE
        return infinicore.ntops.torch.histc(input, bins, min, max, is_moore)

    return Tensor(_infinicore.histc(input._underlying, bins, min, max))
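A sketch of the intended call, assuming `x` is an existing floating-point infinicore tensor (hypothetical); passing explicit `min`/`max` bounds is the safest way to reach the native path, since the C++ binding declared above takes them as doubles:

```python
import infinicore

# `x` is assumed to be an existing floating-point infinicore Tensor.
hist = infinicore.histc(x, bins=10, min=0.0, max=1.0)
# `hist` holds the count of elements of `x` falling into each of the 10
# equal-width bins spanning [0.0, 1.0].
```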
16 changes: 16 additions & 0 deletions python/infinicore/ops/log10.py
@@ -0,0 +1,16 @@
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def log10(input: Tensor, *, out=None) -> Tensor:
    r"""Compute the base-10 logarithm of the input, element-wise."""

    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
        return infinicore.ntops.torch.log10(input, out=out)

    if out is None:
        return Tensor(_infinicore.log10(input._underlying))

    _infinicore.log10_(out._underlying, input._underlying)
    return out
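Both calling forms of the element-wise wrappers look the same; a sketch for `log10`, assuming `x` is an existing infinicore tensor with positive entries (hypothetical):

```python
import infinicore

# `x` is assumed to be an existing infinicore Tensor with positive entries.
y = infinicore.log10(x)  # out-of-place

out = infinicore.empty([x.size(i) for i in range(x.ndim)], dtype=x.dtype, device=x.device)
infinicore.log10(x, out=out)  # writes into a preallocated tensor
```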
16 changes: 16 additions & 0 deletions python/infinicore/ops/log1p.py
@@ -0,0 +1,16 @@
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def log1p(input: Tensor, *, out=None) -> Tensor:
    r"""Compute ln(1 + input), element-wise."""

    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
        return infinicore.ntops.torch.log1p(input, out=out)

    if out is None:
        return Tensor(_infinicore.log1p(input._underlying))

    _infinicore.log1p_(out._underlying, input._underlying)
    return out
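`log1p` is the numerically careful way to evaluate ln(1 + x) when x is close to zero, where forming `1 + x` first would lose precision. A sketch, assuming `x` is an existing infinicore tensor with entries greater than -1 (hypothetical):

```python
import infinicore

# `x` is assumed to be an existing infinicore Tensor with entries > -1.
y = infinicore.log1p(x)  # element-wise ln(1 + x)
# For entries near zero this is more accurate than computing 1 + x first
# and then taking an ordinary logarithm.
```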
66 changes: 66 additions & 0 deletions src/infinicore/ops/avg_pool3d/avg_pool3d.cc
@@ -0,0 +1,66 @@
#include "infinicore/ops/avg_pool3d.hpp"
#include <cmath>
#include <stdexcept>

namespace infinicore::op {

common::OpDispatcher<AvgPool3d::schema> &AvgPool3d::dispatcher() {
static common::OpDispatcher<AvgPool3d::schema> dispatcher_;
return dispatcher_;
};

void AvgPool3d::execute(Tensor output, Tensor input, std::tuple<size_t, size_t, size_t> kernel_size, std::tuple<size_t, size_t, size_t> stride, std::tuple<size_t, size_t, size_t> padding, bool ceil_mode) {
infinicore::context::setDevice(input->device(), true);
auto device_type = context::getDevice().getType();
auto func = dispatcher().lookup(device_type);

if (func == nullptr) {
throw std::runtime_error("No AvgPool3d implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
}

func(output, input, kernel_size, stride, padding, ceil_mode);
}

Tensor avg_pool3d(Tensor input, std::tuple<size_t, size_t, size_t> kernel_size, std::tuple<size_t, size_t, size_t> stride, std::tuple<size_t, size_t, size_t> padding, bool ceil_mode) {
    const auto ndim = input->ndim();
    auto input_shape = input->shape();

    if (ndim != 5 && ndim != 4) {
        throw std::runtime_error("Input tensor must be 5-dimensional (N, C, D_in, H_in, W_in) or 4-dimensional (C, D_in, H_in, W_in)");
    }

    if (ndim == 4) {
        input = input->view({1, input_shape[0], input_shape[1], input_shape[2], input_shape[3]});
        input_shape = input->shape();
    }

    const auto [Kernel_D, Kernel_H, Kernel_W] = kernel_size;
    const auto [Stride_D, Stride_H, Stride_W] = stride;
    const auto [Padding_D, Padding_H, Padding_W] = padding;
    const auto D_in = input_shape[2];
    const auto H_in = input_shape[3];
    const auto W_in = input_shape[4];
    size_t D_out = 0;
    size_t H_out = 0;
    size_t W_out = 0;
    if (ceil_mode) {
        D_out = static_cast<size_t>(std::ceil(static_cast<float>(D_in + 2 * Padding_D - Kernel_D) / Stride_D)) + 1;
        H_out = static_cast<size_t>(std::ceil(static_cast<float>(H_in + 2 * Padding_H - Kernel_H) / Stride_H)) + 1;
        W_out = static_cast<size_t>(std::ceil(static_cast<float>(W_in + 2 * Padding_W - Kernel_W) / Stride_W)) + 1;
    } else {
        D_out = static_cast<size_t>(std::floor(static_cast<float>(D_in + 2 * Padding_D - Kernel_D) / Stride_D)) + 1;
        H_out = static_cast<size_t>(std::floor(static_cast<float>(H_in + 2 * Padding_H - Kernel_H) / Stride_H)) + 1;
        W_out = static_cast<size_t>(std::floor(static_cast<float>(W_in + 2 * Padding_W - Kernel_W) / Stride_W)) + 1;
    }

    auto output_shape = Shape{input_shape[0], input_shape[1], D_out, H_out, W_out};

    auto output = Tensor::empty(output_shape, input->dtype(), input->device());
    avg_pool3d_(output, input, kernel_size, stride, padding, ceil_mode);
    return output;
}

void avg_pool3d_(Tensor output, Tensor input, std::tuple<size_t, size_t, size_t> kernel_size, std::tuple<size_t, size_t, size_t> stride, std::tuple<size_t, size_t, size_t> padding, bool ceil_mode) {
    AvgPool3d::execute(output, input, kernel_size, stride, padding, ceil_mode);
}
} // namespace infinicore::op
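The output-shape arithmetic above is the standard pooling formula: `out = floor((in + 2 * pad - kernel) / stride) + 1`, or `ceil` of the same quotient when `ceil_mode` is set. A small standalone Python illustration of that formula (not part of the PR), for one spatial dimension:

```python
import math

def pooled_extent(size: int, kernel: int, stride: int, pad: int, ceil_mode: bool) -> int:
    """Mirror of the D/H/W output-size computation in avg_pool3d.cc."""
    q = (size + 2 * pad - kernel) / stride
    return (math.ceil(q) if ceil_mode else math.floor(q)) + 1

# Example: a depth of 7 with a kernel of 3, stride 2 and no padding.
assert pooled_extent(7, kernel=3, stride=2, pad=0, ceil_mode=False) == 3  # floor(2.0) + 1
assert pooled_extent(7, kernel=3, stride=2, pad=0, ceil_mode=True) == 3   # ceil(2.0) + 1
assert pooled_extent(8, kernel=3, stride=2, pad=0, ceil_mode=True) == 4   # ceil(2.5) + 1
```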