Skip to content

Commit d320536

Browse files
committed
Use GADT to avoid closures
1 parent 5f3a39d commit d320536

File tree

6 files changed

+84
-77
lines changed

6 files changed

+84
-77
lines changed

src/kcas/kcas.ml

+73-52
Original file line numberDiff line numberDiff line change
@@ -685,83 +685,104 @@ module Xt = struct
685685
(* Fenceless is safe as we are accessing a private location. *)
686686
xt_r.mode == `Obstruction_free && 0 <= loc.id
687687

688-
let[@inline] update_new loc f xt lt gt =
689-
(* Fenceless is safe inside transactions as each log update has a fence. *)
688+
type (_, _) up =
689+
| Get : (unit, 'a) up
690+
| Fetch_and_add : (int, int) up
691+
| Exchange : ('a, 'a) up
692+
| Fn : ('a -> 'a, 'a) up
693+
| Compare_and_swap : ('a * 'a, 'a) up
694+
695+
let[@inline] update :
696+
type c a. 'x t -> a loc -> c -> (c, a) up -> _ -> _ -> a state -> a -> a =
697+
fun xt loc c up lt gt state before ->
698+
let after =
699+
match up with
700+
| Get -> before
701+
| Fetch_and_add -> before + c
702+
| Exchange -> c
703+
| Compare_and_swap -> if fst c == before then snd c else before
704+
| Fn -> begin
705+
let rot = !(tree_as_ref xt) in
706+
match c before with
707+
| after ->
708+
assert (rot == !(tree_as_ref xt));
709+
after
710+
| exception exn ->
711+
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
712+
raise exn
713+
end
714+
in
715+
let state =
716+
if before == after && is_obstruction_free xt loc then state
717+
else { before; after; which = W xt; awaiters = [] }
718+
in
719+
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
720+
before
721+
722+
let[@inline] update_new :
723+
type c a. 'x t -> a loc -> c -> (c, a) up -> _ -> _ -> a =
724+
fun xt loc c up lt gt ->
690725
let state = fenceless_get (as_atomic loc) in
691726
let before = eval state in
692-
match f before with
693-
| after ->
694-
let state =
695-
if before == after && is_obstruction_free xt loc then state
696-
else { before; after; which = W xt; awaiters = [] }
697-
in
698-
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
699-
before
700-
| exception exn ->
701-
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
702-
raise exn
727+
update xt loc c up lt gt state before
703728

704-
let[@inline] update_top loc f xt state' lt gt =
705-
let state = Obj.magic state' in
706-
if is_cmp xt state then begin
707-
let before = eval state in
708-
let after = f before in
709-
let state =
710-
if before == after then state
711-
else { before; after; which = W xt; awaiters = [] }
729+
let[@inline] update_top :
730+
type c a. 'x t -> a loc -> c -> (c, a) up -> _ -> _ -> _ -> a =
731+
fun xt loc c up lt gt state' ->
732+
let state : a state = Obj.magic state' in
733+
if is_cmp xt state then update xt loc c up lt gt state (eval state)
734+
else
735+
let before = state.after in
736+
let after =
737+
match up with
738+
| Get -> before
739+
| Fetch_and_add -> before + c
740+
| Exchange -> c
741+
| Compare_and_swap -> if fst c == before then snd c else before
742+
| Fn ->
743+
let rot = !(tree_as_ref xt) in
744+
let after = c before in
745+
assert (rot == !(tree_as_ref xt));
746+
after
712747
in
748+
let state = if before == after then state else { state with after } in
713749
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
714750
before
715-
end
716-
else
717-
let current = state.after in
718-
let state = { state with after = f current } in
719-
tree_as_ref xt := T (Node { loc; state; lt; gt; awaiters = [] });
720-
current
721751

722-
let[@inline] unsafe_update ~xt loc f =
752+
let update_as ~xt loc c up =
723753
let loc = Loc.to_loc loc in
724754
maybe_validate_log xt;
725755
let x = loc.id in
726756
match !(tree_as_ref xt) with
727-
| T Leaf -> update_new loc f xt (T Leaf) (T Leaf)
757+
| T Leaf -> update_new xt loc c up (T Leaf) (T Leaf)
728758
| T (Node { loc = a; lt = T Leaf; _ }) as tree when x < a.id ->
729-
update_new loc f xt (T Leaf) tree
759+
update_new xt loc c up (T Leaf) tree
730760
| T (Node { loc = a; gt = T Leaf; _ }) as tree when a.id < x ->
731-
update_new loc f xt tree (T Leaf)
761+
update_new xt loc c up tree (T Leaf)
732762
| T (Node { loc = a; state; lt; gt; _ }) when Obj.magic a == loc ->
733-
update_top loc f xt state lt gt
763+
update_top xt loc c up lt gt state
734764
| tree -> begin
735765
match splay ~hit_parent:false x tree with
736-
| l, T Leaf, r -> update_new loc f xt l r
737-
| l, T (Node node_r), r -> update_top loc f xt node_r.state l r
766+
| l, T Leaf, r -> update_new xt loc c up l r
767+
| l, T (Node node_r), r -> update_top xt loc c up l r node_r.state
738768
end
739769

740-
let[@inline] protect xt f x =
741-
let tree = !(tree_as_ref xt) in
742-
let y = f x in
743-
assert (!(tree_as_ref xt) == tree);
744-
y
745-
746-
let get ~xt loc = unsafe_update ~xt loc Fun.id
747-
let set ~xt loc after = unsafe_update ~xt loc (fun _ -> after) |> ignore
748-
let modify ~xt loc f = unsafe_update ~xt loc (protect xt f) |> ignore
770+
let get ~xt loc = update_as ~xt loc () Get
771+
let set ~xt loc after = update_as ~xt loc after Exchange |> ignore
772+
let modify ~xt loc f = update_as ~xt loc f Fn |> ignore
749773

750774
let compare_and_swap ~xt loc before after =
751-
unsafe_update ~xt loc (fun actual ->
752-
if actual == before then after else actual)
775+
update_as ~xt loc (before, after) Compare_and_swap
753776

754777
let compare_and_set ~xt loc before after =
755778
compare_and_swap ~xt loc before after == before
756779

757-
let exchange ~xt loc after = unsafe_update ~xt loc (fun _ -> after)
758-
let fetch_and_add ~xt loc n = unsafe_update ~xt loc (( + ) n)
759-
let incr ~xt loc = unsafe_update ~xt loc inc |> ignore
760-
let decr ~xt loc = unsafe_update ~xt loc dec |> ignore
761-
let update ~xt loc f = unsafe_update ~xt loc (protect xt f)
780+
let exchange ~xt loc after = update_as ~xt loc after Exchange
781+
let fetch_and_add ~xt loc n = update_as ~xt loc n Fetch_and_add
782+
let incr ~xt loc = update_as ~xt loc 1 Fetch_and_add |> ignore
783+
let decr ~xt loc = update_as ~xt loc (-1) Fetch_and_add |> ignore
784+
let update ~xt loc f = update_as ~xt loc f Fn
762785
let swap ~xt l1 l2 = set ~xt l1 @@ exchange ~xt l2 @@ get ~xt l1
763-
let unsafe_modify ~xt loc f = unsafe_update ~xt loc f |> ignore
764-
let unsafe_update ~xt loc f = unsafe_update ~xt loc f
765786

766787
let[@inline] to_blocking ~xt tx =
767788
match tx ~xt with None -> Retry.later () | Some value -> value

src/kcas/kcas.mli

-10
Original file line numberDiff line numberDiff line change
@@ -558,14 +558,4 @@ module Xt : sig
558558
The default {{!Mode.t} [mode]} for [commit] is [`Obstruction_free].
559559
However, after enough attempts have failed during the verification step,
560560
[commit] switches to [`Lock_free]. *)
561-
562-
(**/**)
563-
564-
val unsafe_modify : xt:'x t -> 'a Loc.t -> ('a -> 'a) -> unit
565-
(** [unsafe_modify ~xt r f] is equivalent to [modify ~xt r f], but does not
566-
assert against misuse. *)
567-
568-
val unsafe_update : xt:'x t -> 'a Loc.t -> ('a -> 'a) -> 'a
569-
(** [unsafe_update ~xt r f] is equivalent to [update ~xt r f], but does not
570-
assert against misuse. *)
571561
end

src/kcas_data/hashtbl.ml

+4-4
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ module Xt = struct
219219
Array.unsafe_get old_buckets i
220220
|> Xt.get ~xt
221221
|> Assoc.iter_rev @@ fun k v ->
222-
Xt.unsafe_modify ~xt
222+
Xt.modify ~xt
223223
(Array.unsafe_get new_buckets (hash k land mask))
224224
(Assoc.cons k v)
225225
done
@@ -337,7 +337,7 @@ module Xt = struct
337337
let buckets = r.buckets in
338338
let mask = Array.length buckets - 1 in
339339
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
340-
match Xt.unsafe_modify ~xt bucket (Assoc.remove r.equal k) with
340+
match Xt.modify ~xt bucket (Assoc.remove r.equal k) with
341341
| () ->
342342
Accumulator.Xt.decr ~xt r.length;
343343
if r.min_buckets <= mask && Random.bits () land mask = 0 then
@@ -353,7 +353,7 @@ module Xt = struct
353353
let buckets = r.buckets in
354354
let mask = Array.length buckets - 1 in
355355
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
356-
Xt.unsafe_modify ~xt bucket (Assoc.cons k v);
356+
Xt.modify ~xt bucket (Assoc.cons k v);
357357
Accumulator.Xt.incr ~xt r.length;
358358
if mask + 1 < r.max_buckets && Random.bits () land mask = 0 then
359359
let capacity = mask + 1 in
@@ -367,7 +367,7 @@ module Xt = struct
367367
let mask = Array.length buckets - 1 in
368368
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
369369
let change = ref Assoc.Nop in
370-
Xt.unsafe_modify ~xt bucket (fun kvs ->
370+
Xt.modify ~xt bucket (fun kvs ->
371371
let kvs' = Assoc.replace r.equal change k v kvs in
372372
if !change != Assoc.Nop then kvs' else kvs);
373373
if !change == Assoc.Added then begin

src/kcas_data/mvar.ml

+2-3
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@ module Xt = struct
1111
Magic_option.is_none
1212
(Xt.compare_and_swap ~xt mv Magic_option.none (Magic_option.some value))
1313

14-
let put ~xt mv value =
15-
Xt.unsafe_modify ~xt mv (Magic_option.put_or_retry value)
14+
let put ~xt mv value = Xt.modify ~xt mv (Magic_option.put_or_retry value)
1615

1716
let take_opt ~xt mv =
1817
Magic_option.to_option (Xt.exchange ~xt mv Magic_option.none)
1918

2019
let take ~xt mv =
21-
Magic_option.get_unsafe (Xt.unsafe_update ~xt mv Magic_option.take_or_retry)
20+
Magic_option.get_unsafe (Xt.update ~xt mv Magic_option.take_or_retry)
2221

2322
let peek ~xt mv = Magic_option.get_or_retry (Xt.get ~xt mv)
2423
let peek_opt ~xt mv = Magic_option.to_option (Xt.get ~xt mv)

src/kcas_data/queue.ml

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ module Xt = struct
3434
+ Elems.length (Xt.get ~xt middle)
3535
+ Elems.length (Xt.get ~xt back)
3636

37-
let add ~xt x q = Xt.unsafe_modify ~xt q.back @@ Elems.cons x
37+
let add ~xt x q = Xt.modify ~xt q.back @@ Elems.cons x
3838
let push = add
3939

4040
(** Cooperative helper to move elems from back to middle. *)
@@ -53,7 +53,7 @@ module Xt = struct
5353

5454
let take_opt ~xt t =
5555
let front = t.front in
56-
let elems = Xt.unsafe_update ~xt front Elems.tl_safe in
56+
let elems = Xt.update ~xt front Elems.tl_safe in
5757
if elems != Elems.empty then Elems.hd_opt elems
5858
else
5959
let middle = t.middle and back = t.back in

src/kcas_data/stack.ml

+3-6
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,10 @@ let of_seq xs = Loc.make ~padded:true (Elems.of_seq_rev xs)
99
module Xt = struct
1010
let length ~xt s = Xt.get ~xt s |> Elems.length
1111
let is_empty ~xt s = Xt.get ~xt s == Elems.empty
12-
let push ~xt x s = Xt.unsafe_modify ~xt s @@ Elems.cons x
13-
let pop_opt ~xt s = Xt.unsafe_update ~xt s Elems.tl_safe |> Elems.hd_opt
12+
let push ~xt x s = Xt.modify ~xt s @@ Elems.cons x
13+
let pop_opt ~xt s = Xt.update ~xt s Elems.tl_safe |> Elems.hd_opt
1414
let pop_all ~xt s = Elems.to_seq @@ Xt.exchange ~xt s Elems.empty
15-
16-
let pop_blocking ~xt s =
17-
Xt.unsafe_update ~xt s Elems.tl_safe |> Elems.hd_or_retry
18-
15+
let pop_blocking ~xt s = Xt.update ~xt s Elems.tl_safe |> Elems.hd_or_retry
1916
let top_opt ~xt s = Xt.get ~xt s |> Elems.hd_opt
2017
let top_blocking ~xt s = Xt.get ~xt s |> Elems.hd_or_retry
2118
let clear ~xt s = Xt.set ~xt s Elems.empty

0 commit comments

Comments
 (0)