Skip to content

Wrappers for additional tblis functions #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# PEP 517/518 build configuration: build the package with setuptools.
[build-system]
# NOTE(review): >= 64 presumably chosen for PEP 660 editable-install
# support in setuptools — confirm against the project's install docs.
requires = ["setuptools >= 64"]
build-backend = "setuptools.build_meta"
31 changes: 18 additions & 13 deletions pyscf/tblis_einsum/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,30 @@ if (CMAKE_COMPILER_IS_GNUCC) # Does it skip the link flag on old OsX?
endif()
endif()

# See also https://gitlab.kitware.com/cmake/community/wikis/doc/cmake/RPATH-handling
if (WIN32)
#?
elseif (APPLE)
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
set(CMAKE_INSTALL_RPATH "@loader_path")
set(CMAKE_BUILD_RPATH "@loader_path")
else ()
set(CMAKE_SKIP_BUILD_RPATH True)
set(CMAKE_BUILD_WITH_INSTALL_RPATH True)
set(CMAKE_INSTALL_RPATH "\$ORIGIN")
endif ()
option(VENDOR_TBLIS "Download and build tblis" on)

if(VENDOR_TBLIS)
# The following is needed because TBLIS will be installed in the same folder
# as the built CPython extension.
# See also https://gitlab.kitware.com/cmake/community/wikis/doc/cmake/RPATH-handling
if (WIN32)
#?
elseif (APPLE)
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
set(CMAKE_INSTALL_RPATH "@loader_path")
set(CMAKE_BUILD_RPATH "@loader_path")
else ()
set(CMAKE_SKIP_BUILD_RPATH True)
set(CMAKE_BUILD_WITH_INSTALL_RPATH True)
set(CMAKE_INSTALL_RPATH "\$ORIGIN")
endif ()
endif()

set(C_LINK_TEMPLATE "<CMAKE_C_COMPILER> <CMAKE_SHARED_LIBRARY_C_FLAGS> <LANGUAGE_COMPILE_FLAGS> <LINK_FLAGS> <CMAKE_SHARED_LIBRARY_CREATE_C_FLAGS> -o <TARGET> <OBJECTS> <LINK_LIBRARIES>")
set(CXX_LINK_TEMPLATE "<CMAKE_CXX_COMPILER> <CMAKE_SHARED_LIBRARY_CXX_FLAGS> <LANGUAGE_COMPILE_FLAGS> <LINK_FLAGS> <CMAKE_SHARED_LIBRARY_CREATE_CXX_FLAGS> -o <TARGET> <OBJECTS> <LINK_LIBRARIES>")

add_library(tblis_einsum SHARED as_einsum.cxx)

option(VENDOR_TBLIS "Download and build tblis" on)
set_target_properties(tblis_einsum PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}
COMPILE_FLAGS "-std=c++11")
Expand Down
2 changes: 1 addition & 1 deletion pyscf/tblis_einsum/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = '0.1.5'

from .tblis_einsum import contract
from .tblis_einsum import contract, tensor_add, tensor_mult, tensor_dot
45 changes: 44 additions & 1 deletion pyscf/tblis_einsum/as_einsum.cxx
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <stddef.h>
#include <stdlib.h>
#include <string.h>
#include <tblis/tblis.h>
using namespace tblis;

Expand Down Expand Up @@ -42,7 +43,7 @@ static void _as_tensor(tblis_tensor *t, void *data, int dtype, int ndim,
}

extern "C" {
void as_einsum(void *data_A, int ndim_A, ptrdiff_t *shape_A, ptrdiff_t *strides_A, char *descr_A,
void tensor_mult(void *data_A, int ndim_A, ptrdiff_t *shape_A, ptrdiff_t *strides_A, char *descr_A,
void *data_B, int ndim_B, ptrdiff_t *shape_B, ptrdiff_t *strides_B, char *descr_B,
void *data_C, int ndim_C, ptrdiff_t *shape_C, ptrdiff_t *strides_C, char *descr_C,
int dtype, void *alpha, void *beta)
Expand All @@ -54,4 +55,46 @@ void as_einsum(void *data_A, int ndim_A, ptrdiff_t *shape_A, ptrdiff_t *strides_

tblis_tensor_mult(NULL, NULL, &A, descr_A, &B, descr_B, &C, descr_C);
}


void tensor_add(void *data_A, int ndim_A, ptrdiff_t *shape_A, ptrdiff_t *strides_A, char *descr_A,
void *data_B, int ndim_B, ptrdiff_t *shape_B, ptrdiff_t *strides_B, char *descr_B,
int dtype, void *alpha, void *beta)
{
tblis_tensor A, B;
_as_tensor(&A, data_A, dtype, ndim_A, shape_A, strides_A, alpha);
_as_tensor(&B, data_B, dtype, ndim_B, shape_B, strides_B, beta);

tblis_tensor_add(NULL, NULL, &A, descr_A, &B, descr_B);
}

/* C entry point wrapping tblis_tensor_dot.
 *
 * Contracts A and B over the index strings descr_A/descr_B and copies the
 * scalar result into *result, which must point to storage large enough for
 * the given dtype.  For an unrecognized dtype code, *result is left
 * untouched (previously `bytes` stayed uninitialized and the memcpy was
 * undefined behavior). */
void tensor_dot(void *data_A, int ndim_A, ptrdiff_t *shape_A, ptrdiff_t *strides_A, char *descr_A,
                void *data_B, int ndim_B, ptrdiff_t *shape_B, ptrdiff_t *strides_B, char *descr_B,
                int dtype, void *result)
{
    tblis_tensor A, B;
    tblis_scalar s;
    _as_tensor(&A, data_A, dtype, ndim_A, shape_A, strides_A, NULL);
    _as_tensor(&B, data_B, dtype, ndim_B, shape_B, strides_B, NULL);

    tblis_tensor_dot(NULL, NULL, &A, descr_A, &B, descr_B, &s);

    /* size_t instead of ssize_t: the count is never negative and ssize_t
     * is POSIX-only (not declared by the headers this file includes). */
    size_t bytes;
    switch(dtype) {
    case TYPE_SINGLE:
        bytes = sizeof(float);
        break;
    case TYPE_DOUBLE:
        bytes = sizeof(double);
        break;
    case TYPE_SCOMPLEX:
        bytes = sizeof(scomplex);
        break;
    case TYPE_DCOMPLEX:
        bytes = sizeof(dcomplex);
        break;
    default:
        /* Unknown dtype: bail out rather than copy an indeterminate
         * number of bytes. */
        return;
    }

    memcpy(result, &s.data, bytes);
}
}
149 changes: 119 additions & 30 deletions pyscf/tblis_einsum/tblis_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

libtblis = numpy.ctypeslib.load_library('libtblis_einsum', os.path.dirname(__file__))

libtblis.as_einsum.restype = None
libtblis.as_einsum.argtypes = (
libtblis.tensor_mult.restype = None
libtblis.tensor_mult.argtypes = (
numpy.ctypeslib.ndpointer(), ctypes.c_int,
ctypes.POINTER(ctypes.c_size_t), ctypes.POINTER(ctypes.c_size_t),
ctypes.POINTER(ctypes.c_char),
Expand All @@ -35,6 +35,30 @@
numpy.ctypeslib.ndpointer(), numpy.ctypeslib.ndpointer()
)

# ctypes signature of the C wrapper `tensor_add` in as_einsum.cxx:
#   (data_A, ndim_A, shape_A, strides_A, descr_A,
#    data_B, ndim_B, shape_B, strides_B, descr_B,
#    dtype, alpha, beta)
# shapes/strides are passed as size_t arrays, index strings as char*.
libtblis.tensor_add.restype = None
libtblis.tensor_add.argtypes = (
# tensor A: data pointer, rank, shape, strides, index string
numpy.ctypeslib.ndpointer(), ctypes.c_int,
ctypes.POINTER(ctypes.c_size_t), ctypes.POINTER(ctypes.c_size_t),
ctypes.POINTER(ctypes.c_char),
# tensor B: data pointer, rank, shape, strides, index string
numpy.ctypeslib.ndpointer(), ctypes.c_int,
ctypes.POINTER(ctypes.c_size_t), ctypes.POINTER(ctypes.c_size_t),
ctypes.POINTER(ctypes.c_char),
# dtype code (see tblis_dtype mapping below)
ctypes.c_int,
# alpha and beta, each a 0-d ndarray of the operand dtype
numpy.ctypeslib.ndpointer(), numpy.ctypeslib.ndpointer()
)

# ctypes signature of the C wrapper `tensor_dot` in as_einsum.cxx:
#   (data_A, ..., descr_A, data_B, ..., descr_B, dtype, result)
# `result` is a 1-element ndarray the C side writes the scalar into.
libtblis.tensor_dot.restype = None
libtblis.tensor_dot.argtypes = (
# tensor A: data pointer, rank, shape, strides, index string
numpy.ctypeslib.ndpointer(), ctypes.c_int,
ctypes.POINTER(ctypes.c_size_t), ctypes.POINTER(ctypes.c_size_t),
ctypes.POINTER(ctypes.c_char),
# tensor B: data pointer, rank, shape, strides, index string
numpy.ctypeslib.ndpointer(), ctypes.c_int,
ctypes.POINTER(ctypes.c_size_t), ctypes.POINTER(ctypes.c_size_t),
ctypes.POINTER(ctypes.c_char),
# dtype code
ctypes.c_int,
# output buffer for the scalar result
numpy.ctypeslib.ndpointer()
)

tblis_dtype = {
numpy.dtype(numpy.float32) : 0,
numpy.dtype(numpy.double) : 1,
Expand All @@ -44,6 +68,29 @@

EINSUM_MAX_SIZE = 2000

def ctype_strides(*arrays):
    """Return each array's strides, in units of elements (not bytes),
    packed into a ctypes ``c_size_t`` array suitable for the C wrappers."""
    converted = []
    for arr in arrays:
        element_strides = [byte_stride // arr.dtype.itemsize
                           for byte_stride in arr.strides]
        converted.append((ctypes.c_size_t * arr.ndim)(*element_strides))
    return converted

def ctype_shapes(*arrays):
    """Return each array's shape packed into a ctypes ``c_size_t`` array,
    one entry per input array."""
    packed = []
    for arr in arrays:
        packed.append((ctypes.c_size_t * arr.ndim)(*arr.shape))
    return packed

def check_tblis_shapes(a, a_inds, b, b_inds, subscripts=None, c_inds=None):
    """Validate that dimensions sharing an einsum index label agree.

    Raises ValueError when a and b disagree on any common index.  If
    c_inds is given, returns the output shape implied by the operands
    (b's sizes win for labels present in both); otherwise returns None.
    """
    dims = dict(zip(a_inds, a.shape))
    b_dims = dict(zip(b_inds, b.shape))
    if subscripts is None:
        subscripts = a_inds + ',' + b_inds
    for label in set(a_inds).intersection(b_inds):
        if dims[label] != b_dims[label]:
            raise ValueError('operands dimension error for "%s" : %s %s'
                             % (subscripts, a.shape, b.shape))

    if c_inds is None:
        return None
    dims.update(b_dims)
    return tuple(dims[label] for label in c_inds)

_numpy_einsum = numpy.einsum
def contract(subscripts, *tensors, **kwargs):
'''
Expand Down Expand Up @@ -87,45 +134,87 @@ def contract(subscripts, *tensors, **kwargs):
alpha = kwargs.get('alpha', 1)
beta = kwargs.get('beta', 0)
c_dtype = numpy.result_type(c_dtype, alpha, beta)
alpha = numpy.asarray(alpha, dtype=c_dtype)
beta = numpy.asarray(beta , dtype=c_dtype)

a = numpy.asarray(a, dtype=c_dtype)
b = numpy.asarray(b, dtype=c_dtype)
assert len(a_descr) == a.ndim
assert len(b_descr) == b.ndim
a_shape = a.shape
b_shape = b.shape
a_shape_dic = dict(zip(a_descr, a_shape))
b_shape_dic = dict(zip(b_descr, b_shape))
if any(a_shape_dic[x] != b_shape_dic[x]
for x in set(a_descr).intersection(b_descr)):
raise ValueError('operands dimension error for "%s" : %s %s'
% (subscripts, a_shape, b_shape))

ab_shape_dic = a_shape_dic
ab_shape_dic.update(b_shape_dic)
c_shape = tuple([ab_shape_dic[x] for x in c_descr])
c_shape = check_tblis_shapes(a, a_descr, b, b_descr, subscripts=subscripts, c_inds=c_descr)

out = kwargs.get('out', None)
if out is None:
order = kwargs.get('order', 'C')
c = numpy.empty(c_shape, dtype=c_dtype, order=order)
else:
assert(out.dtype == c_dtype)
assert(out.shape == c_shape)
c = out
return tensor_mult(a, a_descr, b, b_descr, c, c_descr, alpha=alpha, beta=beta, dtype=c_dtype)

def tensor_mult(a, a_inds, b, b_inds, c, c_inds, alpha=1, beta=0, dtype=None):
    '''Wrapper for tblis_tensor_mult.

    Performs the einsum operation
        c_{c_inds} = alpha * SUM[a_{a_inds} * b_{b_inds}] + beta * c_{c_inds}
    where the sum is over indices in a_inds and b_inds that are not in
    c_inds.

    Args:
        a, b : input ndarrays.
        c : output ndarray, updated in place.
        a_inds, b_inds, c_inds : einsum index strings, one letter per axis.
        alpha, beta : scalar factors, cast to the common dtype.
        dtype : expected common dtype of a, b and c; defaults to c's dtype.

    Returns:
        c (the same array, modified in place).
    '''
    if dtype is None:
        dtype = c.dtype.type
    # The C layer reads raw buffers, so all operands must share one dtype.
    assert dtype == c.dtype.type
    assert dtype == a.dtype.type
    assert dtype == b.dtype.type

    # alpha/beta are passed as 0-d arrays so ctypes hands over a pointer of
    # the right element type.
    alpha = numpy.asarray(alpha, dtype=dtype)
    beta = numpy.asarray(beta, dtype=dtype)

    # One index letter per axis on every operand.
    assert len(a_inds) == a.ndim
    assert len(b_inds) == b.ndim
    assert len(c_inds) == c.ndim

    a_shape, b_shape, c_shape = ctype_shapes(a, b, c)
    a_strides, b_strides, c_strides = ctype_strides(a, b, c)
    # Shared labels must agree and c must have the implied output shape.
    assert c.shape == check_tblis_shapes(a, a_inds, b, b_inds, c_inds=c_inds)

    libtblis.tensor_mult(a, a.ndim, a_shape, a_strides, a_inds.encode('ascii'),
                         b, b.ndim, b_shape, b_strides, b_inds.encode('ascii'),
                         c, c.ndim, c_shape, c_strides, c_inds.encode('ascii'),
                         tblis_dtype[c.dtype], alpha, beta)
    return c


def tensor_add(a, a_inds, b, b_inds, alpha=1, beta=1):
    '''Wrapper for tblis_tensor_add.

    Updates b in place:
        b_{b_inds} = alpha * a_{a_inds} + beta * b_{b_inds}
    where a may be transposed relative to b via the index strings.

    Args:
        a : input ndarray (unchanged).
        b : output ndarray, updated in place.
        a_inds, b_inds : einsum index strings, one letter per axis.
        alpha, beta : scalar factors, cast to b's dtype.

    Returns:
        b (the same array, modified in place).
    '''
    assert a.dtype.type == b.dtype.type

    # 0-d arrays so ctypes passes a pointer of the right element type.
    alpha = numpy.asarray(alpha, dtype=b.dtype)
    beta = numpy.asarray(beta, dtype=b.dtype)

    assert len(a_inds) == a.ndim
    assert len(b_inds) == b.ndim
    # Validate that dimensions carrying the same index label agree before
    # handing raw buffers to the C layer (mirrors tensor_mult; previously
    # mismatched operands were passed through unchecked).
    check_tblis_shapes(a, a_inds, b, b_inds)

    a_shape, b_shape = ctype_shapes(a, b)
    a_strides, b_strides = ctype_strides(a, b)

    libtblis.tensor_add(a, a.ndim, a_shape, a_strides, a_inds.encode('ascii'),
                        b, b.ndim, b_shape, b_strides, b_inds.encode('ascii'),
                        tblis_dtype[b.dtype], alpha, beta)
    return b

def tensor_dot(a, a_inds, b, b_inds):
    '''Wrapper for tblis_tensor_dot.

    Computes the full contraction
        SUM[a_{a_inds} * b_{b_inds}]
    over all indices and returns the scalar result with the operands' dtype.

    Args:
        a, b : input ndarrays of the same dtype.
        a_inds, b_inds : einsum index strings, one letter per axis.

    Returns:
        A scalar of the operands' dtype.
    '''
    assert a.dtype.type == b.dtype.type

    assert len(a_inds) == a.ndim
    assert len(b_inds) == b.ndim
    # Shared index labels must refer to equal dimensions; fail early in
    # Python rather than passing inconsistent buffers to the C layer
    # (mirrors tensor_mult; previously there was no such check here).
    check_tblis_shapes(a, a_inds, b, b_inds)

    a_shape, b_shape = ctype_shapes(a, b)
    a_strides, b_strides = ctype_strides(a, b)

    # 1-element buffer the C function writes the scalar result into.
    result = numpy.zeros(1, dtype=a.dtype.type)

    libtblis.tensor_dot(a, a.ndim, a_shape, a_strides, a_inds.encode('ascii'),
                        b, b.ndim, b_shape, b_strides, b_inds.encode('ascii'),
                        tblis_dtype[b.dtype], result)

    return result[0]
26 changes: 25 additions & 1 deletion pyscf/tblis_einsum/tests/test_einsum.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import numpy as np
from pyscf.tblis_einsum import tblis_einsum
from pyscf.tblis_einsum import tblis_einsum, tensor_add, tensor_mult, tensor_dot

def setUpModule():
global bak
Expand Down Expand Up @@ -129,6 +129,30 @@ def test_contraction6(self):
res_tblis = tblis_einsum.contract("ij,jk->k", c, c)
self.assertTrue (abs(res_np - res_tblis).max() < 1e-13)

def test_tensor_add(self):
    # In-place add with a transposed left operand: after the call, y must
    # equal transpose(x) + y as computed by numpy.
    x = np.random.random((3, 3, 3, 3))
    y = np.random.random((3, 3, 3, 3))
    expected = np.transpose(x, (1, 2, 3, 0)) + y
    tensor_add(x, 'lijk', y, 'ijkl')
    self.assertTrue(abs(expected - y).max() < 1e-14)

def test_tensor_add_scaled(self):
    # Scaled in-place add: y <- alpha*x + beta*y with identical index
    # strings (no transpose), checked against the numpy expression.
    x = np.random.random((7, 1, 3, 4))
    y = np.random.random((7, 1, 3, 4))
    scale_x = 2.0
    scale_y = 3.0
    expected = scale_x * x + scale_y * y
    tensor_add(x, 'ijkl', y, 'ijkl', alpha=scale_x, beta=scale_y)
    self.assertTrue(abs(expected - y).max() < 1e-14)

def test_tensor_dot(self):
    # Full contraction must agree with the numpy.einsum reference value.
    x = np.random.random((3, 3, 3, 3))
    y = np.random.random((3, 3, 3, 3))
    reference = np.einsum('lijk,ijkl->', x, y)
    value = tensor_dot(x, 'lijk', y, 'ijkl')
    self.assertTrue(abs(reference - value) < 1e-14)



if __name__ == '__main__':
unittest.main()