Skip to content

Starting point for a modern fortran module #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 244 additions & 0 deletions src/symengine.f08
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
module symengine

use iso_c_binding, only: c_size_t, c_long, c_char, c_ptr, c_null_ptr, c_null_char, c_f_pointer, c_associated
implicit none

interface
function c_strlen(string) bind(C, name="strlen")
import :: c_size_t, c_ptr
type(c_ptr), intent(in), value :: string
integer(kind=c_size_t) :: c_strlen
end function c_strlen
function c_basic_new_heap() bind(c, name='basic_new_heap')
import :: c_ptr
type(c_ptr) :: c_basic_new_heap
end function
subroutine c_basic_free_heap(s) bind(c, name='basic_free_heap')
import :: c_ptr
type(c_ptr), value :: s
end subroutine
function c_basic_assign(a, b) bind(c, name='basic_assign')
import c_long, c_ptr
type(c_ptr), value :: a, b
integer(c_long) :: c_basic_assign
end function
function c_basic_str(s) bind(c, name='basic_str')
import :: c_ptr
type(c_ptr), value :: s
type(c_ptr) :: c_basic_str
end function
function c_basic_parse(s, c) bind(c, name='basic_parse')
import c_long, c_ptr, c_char
type(c_ptr), value :: s
character(kind=c_char), dimension(*) :: c
integer(c_long) :: c_basic_parse
end function
subroutine c_basic_str_free(s) bind(c, name='basic_str_free')
import :: c_ptr
type(c_ptr), value :: s
end subroutine
function c_basic_add(s, a, b) bind(c, name='basic_add')
import :: c_long, c_ptr
type(c_ptr), value :: s, a, b
integer(c_long) :: c_basic_add
end function
function c_basic_sub(s, a, b) bind(c, name='basic_sub')
import :: c_long, c_ptr
type(c_ptr), value :: s, a, b
integer(c_long) :: c_basic_sub
end function
function c_basic_mul(s, a, b) bind(c, name='basic_mul')
import :: c_long, c_ptr
type(c_ptr), value :: s, a, b
integer(c_long) :: c_basic_mul
end function
function c_basic_div(s, a, b) bind(c, name='basic_div')
import :: c_long, c_ptr
type(c_ptr), value :: s, a, b
integer(c_long) :: c_basic_div
end function
function c_basic_pow(s, a, b) bind(c, name='basic_pow')
import :: c_long, c_ptr
type(c_ptr), value :: s, a, b
integer(c_long) :: c_basic_pow
end function
function c_integer_set_si(s, i) bind(c, name='integer_set_si')
import :: c_long, c_ptr
type(c_ptr), value :: s
integer(c_long), value :: i
integer(c_long) :: c_integer_set_si
end function
function c_integer_get_si(s) bind(c, name='integer_get_si')
import c_long, c_ptr
type(c_ptr), value :: s
integer(c_long) :: c_integer_get_si
end function
function c_symbol_set(s, c) bind(c, name='symbol_set')
import c_long, c_ptr, c_char
type(c_ptr), value :: s
character(kind=c_char), dimension(*) :: c
integer(c_long) :: c_symbol_set
end function
end interface


type Basic
type(c_ptr) :: ptr = c_null_ptr
logical :: tmp = .false.
contains
procedure :: str, basic_assign, basic_add, basic_sub, basic_mul, basic_div, basic_pow
generic :: assignment(=) => basic_assign
generic :: operator(+) => basic_add
generic :: operator(-) => basic_sub
generic :: operator(*) => basic_mul
generic :: operator(/) => basic_div
generic :: operator(**) => basic_pow
final :: basic_free
end type

interface Basic
module procedure basic_new
end interface

type, extends(Basic) :: SymInteger
contains
procedure :: get
end type SymInteger

interface SymInteger
module procedure integer_new
end interface

type, extends(Basic) :: Symbol
end type Symbol

interface Symbol
module procedure symbol_new
end interface


contains


function basic_new() result(new)
type(Basic) :: new
new%ptr = c_basic_new_heap()
end function

subroutine basic_free(this)
type(Basic) :: this
call c_basic_free_heap(this%ptr)
end subroutine

function str(e)
class(Basic) :: e
character, pointer, dimension(:) :: tempstr
character(:), allocatable :: str
type(c_ptr) :: cstring
integer :: nchars
cstring = c_basic_str(e%ptr)
nchars = c_strlen(cstring)
call c_f_pointer(cstring, tempstr, [nchars])
allocate(character(len=nchars) :: str)
str = transfer(tempstr(1:nchars), str)
call c_basic_str_free(cstring)
end function

subroutine basic_assign(a, b)
class(basic), intent(inout) :: a
class(basic), intent(in) :: b
integer(c_long) :: dummy
if (.not. c_associated(a%ptr)) then
a%ptr = c_basic_new_heap()
end if
dummy = c_basic_assign(a%ptr, b%ptr)
if (b%tmp) then
call basic_free(b)
end if
end subroutine

function basic_add(a, b)
class(basic), intent(in) :: a, b
type(basic) :: basic_add
integer(c_long) :: dummy
basic_add = Basic()
dummy = c_basic_add(basic_add%ptr, a%ptr, b%ptr)
basic_add%tmp = .true.
end function

function basic_sub(a, b)
class(basic), intent(in) :: a, b
type(basic) :: basic_sub
integer(c_long) :: dummy
basic_sub = Basic()
dummy = c_basic_sub(basic_sub%ptr, a%ptr, b%ptr)
basic_sub%tmp = .true.
end function

function basic_mul(a, b)
class(basic), intent(in) :: a, b
type(basic) :: basic_mul
integer(c_long) :: dummy
basic_mul = Basic()
dummy = c_basic_mul(basic_mul%ptr, a%ptr, b%ptr)
basic_mul%tmp = .true.
end function

function basic_div(a, b)
class(basic), intent(in) :: a, b
type(basic) :: basic_div
integer(c_long) :: dummy
basic_div = Basic()
dummy = c_basic_div(basic_div%ptr, a%ptr, b%ptr)
basic_div%tmp = .true.
end function

function basic_pow(a, b)
class(basic), intent(in) :: a, b
type(basic) :: basic_pow
integer(c_long) :: dummy
basic_pow = Basic()
dummy = c_basic_pow(basic_pow%ptr, a%ptr, b%ptr)
basic_pow%tmp = .true.
end function

function integer_new(i)
integer :: i
integer(c_long) :: j
integer(c_long) :: dummy
type(SymInteger) :: integer_new
j = int(i)
integer_new%ptr = c_basic_new_heap()
dummy = c_integer_set_si(integer_new%ptr, j)
integer_new%tmp = .true.
end function

function get(this) result(i)
class(SymInteger) :: this
integer :: i
i = int(c_integer_get_si(this%ptr))
end function

function symbol_new(c)
character(len=*) :: c
character(len=len_trim(c) + 1) :: new_c
integer(c_long) :: dummy
type(Symbol) :: symbol_new
new_c = trim(c) // c_null_char
symbol_new%ptr = c_basic_new_heap()
symbol_new%tmp = .true.
dummy = c_symbol_set(symbol_new%ptr, new_c)
end function

function parse(c)
character(len=*) :: c
type(Basic) :: parse
integer(c_long) :: dummy
character(len=len_trim(c) + 1) :: new_c
new_c = trim(c) // c_null_char
parse%ptr = c_basic_new_heap()
dummy = c_basic_parse(parse%ptr, new_c)
parse%tmp = .true.
end function

end module
23 changes: 23 additions & 0 deletions src/test.f08
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
subroutine dostuff()
use symengine
type(Basic) :: a, b, c

a = SymInteger(12)
b = Symbol('x')
c = a * b
print *, c%str()
c = parse('2*(24+x)')
print *, c%str()
end subroutine



program test

implicit none

call dostuff

print *, "Finishing"

end program