1 (* This module implements association maps using height-balanced trees.
   2    The code is modeled after OCaml's [Map] library, but has been modified
   3    to allow trees to be modified in place. *)
   4 
   5 (* -------------------------------------------------------------------------- *)
   6 
   7 (* Some stuff that should be moved to another file. *) (* TEMPORARY *)
   8 
   9 val max (x: int, y: int) : int =
  10   if x >= y then x else y
  11 
  12 (* -------------------------------------------------------------------------- *)
  13 
  14 (* A tree is either empty or a binary node. Besides its children, a binary
  15    node contains a key, a value, and its height. *)
  16 
mutable data tree k a =
  (* The empty tree. *)
  | Empty
  (* A binary node. The [height] field caches the height of the subtree
     rooted at this node (a singleton has height 1); it is read by [height]
     and refreshed by [update_height] and [bal]. *)
  | Node { left: tree k a; key: k; value: a; right: tree k a; height: int }
  20 
  21 (* -------------------------------------------------------------------------- *)
  22 
  23 (* Cardinal. *)
  24 
  25 val rec cardinal [k, a] (t: tree k a) : int =
  26   match t with
  27   | Empty ->
  28       0
  29   | Node ->
  30       cardinal t.left + 1 + cardinal t.right
  31   end
  32 
  33 (* -------------------------------------------------------------------------- *)
  34 
  35 (* The following (private) function reads the height information that is
  36    stored in a node. It does not recompute anything. *)
  37 
  38 val height [k, a] (t: tree k a) : int =
  39   match t with
  40   | Empty -> 0
  41   | Node  -> t.height
  42   end
  43 
  44 (* The following (private) function updates the height information that is
  45    stored in a node, based on the height of its children. *)
  46 
  47 val update_height [k, a] (
  48   consumes t : Node { left: tree k a; key: k; value: a; right: tree k a; height: unknown }
  49 ) : ( |    t @ Node { left: tree k a; key: k; value: a; right: tree k a; height: int     }) =
  50   t.height <- max (height t.left, height t.right) + 1
  51 
  52 (* The following (private) function re-organizes a tree, if required, by
  53    performing a rotation at the root. The left and right sub-trees are
  54    expected to have almost equal heights. The address of the new tree root
  55    is returned. *)
  56 
val bal [k, a] (
  consumes t: Node { left: tree k a; key: k; value: a; right: tree k a; height: unknown }
) : tree k a =

  (* Extract the two sub-trees and their heights. *)
  let Node { left = l; right = r } = t in
  let hl = height l
  and hr = height r in

  (* Determine whether the tree is unbalanced and needs to be repaired.
     A height difference of up to 2 is tolerated, as in OCaml's [Map]. *)

  (* Situation 1: the left sub-tree is too big. *)
  if hl > hr + 2 then match l with
  | Empty -> fail (* impossible! [hl > hr + 2 >= 2], so [l] is non-empty *)
  | Node { left = ll; right = lr } ->
      if height ll >= height lr then begin
        (* Single rotation. *)
        (* The left node becomes the root node. *)
        (* The root node becomes the right child. *)
        (* [t] ends up below [l], so its height is refreshed first. *)
        t.left <- lr;
        update_height t;
        l.right <- t;
        update_height l;
        l
      end
      else match lr with
      | Empty -> fail (* impossible! here [height lr > height ll >= 0] *)
      | Node { left = lrl; right = lrr } ->
          (* Double rotation. *)
          (* The node [lr] becomes the root node. *)
          (* The root node becomes the right child. *)
          (* The left node remains the left child. *)
          (* [l] and [t] are refreshed before [lr], whose height depends
             on theirs. *)
          l.right <- lrl;
          update_height l;
          t.left <- lrr;
          update_height t;
          lr.left <- l;
          lr.right <- t;
          update_height lr;
          lr
      end
  end

  (* Situation 2: the right sub-tree is too big. *)
  else if hr > hl + 2 then match r with
  | Empty -> fail (* impossible! [hr > hl + 2 >= 2], so [r] is non-empty *)
  | Node { left = rl; right = rr } ->
      if height rr >= height rl then begin
        (* Single rotation (mirror image of situation 1). *)
        (* The right node becomes the root node. *)
        (* The root node becomes the left child. *)
        t.right <- rl;
        update_height t;
        r.left <- t;
        update_height r;
        r
      end
      else match rl with
      | Empty -> fail (* impossible! here [height rl > height rr >= 0] *)
      | Node { left = rll; right = rlr } ->
          (* Double rotation (mirror image of situation 1). *)
          (* The node [rl] becomes the root node. *)
          (* The root node becomes the left child. *)
          (* The right node remains the right child. *)
          t.right <- rll;
          update_height t;
          r.left <- rlr;
          update_height r;
          rl.left <- t;
          rl.right <- r;
          update_height rl;
          rl
      end
  end

  (* Last situation: the tree is not unbalanced. *)
  (* Just update its height field. [t] remains the root. The heights
     [hl] and [hr] computed above are reused. *)
  else begin
    t.height <- max (hl, hr) + 1;
    t
  end
 134 
 135 (* -------------------------------------------------------------------------- *)
 136 
 137 (* Creating an empty tree. *)
 138 
val create [k, a] () : tree k a =
  (* An empty map is represented by the [Empty] constructor. *)
  Empty
 141 
 142 (* Creating a singleton tree. *)
 143 
val singleton [k, a] (consumes (x: k, d: a)) : tree k a =
  (* Unlike [add], the key and the value are passed as a single pair,
     which is consumed. The new node has two [Empty] children and
     height 1. *)
  Node { left = Empty; key = x; value = d; right = Empty; height = 1 }
 146 
 147 (* Testing whether a tree is empty. *)
 148 
 149 val is_empty [k, a] (t : tree k a) : bool =
 150   match t with
 151   | Empty -> True
 152   | Node  -> False
 153   end
 154 
 155 (* -------------------------------------------------------------------------- *)
 156 
 157 (* Insertion. *)
 158 
 159 val rec add [k, a] (
 160   cmp: (k, k) -> int,
 161   consumes x: k,
 162   consumes d: a,
 163   consumes t: tree k a
 164 ) : tree k a =
 165   match t with
 166   | Empty ->
 167       (* Create a singleton tree. *)
 168       Node { left = t; key = x; value = d; right = Empty; height = 1 }
 169   | Node ->
 170       let c = cmp (x, t.key) in
 171       if c = 0 then begin
 172         (* The key already exists; overwrite the previous data *)
 173         t.value <- d;
 174         t
 175       end
 176       else if c < 0 then begin
 177         t.left <- add (cmp, x, d, t.left);
 178         bal t
 179       end
 180       else begin
 181         t.right <- add (cmp, x, d, t.right);
 182         bal t
 183       end
 184   end
 185 
 186 (* -------------------------------------------------------------------------- *)
 187 
 188 (* Lookup. *)
 189 
 190 (* It seems that the function [find] must require [duplicable a].
 191    Indeed, without this hypothesis, we would be forced to consume
 192    the argument tree [t], which does not seem reasonable. *)
 193 
 194 val rec find [k, a] duplicable a => (
 195   cmp: (k, k) -> int,
 196   x: k,
 197   t: tree k a
 198 ) : option a =
 199   match t with
 200   | Empty ->
 201       none
 202   | Node ->
 203       let c = cmp (x, t.key) in
 204       if c = 0 then some t.value
 205       (* It is interesting to note that we cannot write the more compact code:
 206          find (cmp, x, (if c < 0 then t.left else t.right))
 207          Indeed, the type-checker is unable to figure out the desired type of
 208          the conditional sub-expression; it reports a resource allocation
 209          conflict. In fact, if we wanted to explicitly declare this type,
 210          I believe that we would need a magic wand: this sub-expression
 211          produces a result [s] together with the permissions [s @ tree k a]
 212          and [s @ tree k a -* t @ tree k a]. *)
 213       else if c < 0 then find (cmp, x, t.left)
 214       else find (cmp, x, t.right)
 215   end
 216 
 217 (* The above [find] function requires [a] to be duplicable. Another approach
 218    is to parameterize [find] with a [copy] function that is able to copy an
 219    element of type [a]. In fact, an even more general idea is to offer an
 220    [update] function that allows the caller to access the value found at the
 221    key [x] within a lexically-delimited scope, and then to surrender it (or
 222    a new version of it). *)
 223 
 224 (* Because the key [x] may be absent, the function [f] is called either never
 225    or just once. Our use of a [preserved/consumed] permission pair allows
 226    reflecting this. A [pre/post] permission pair would be more precise, but
 227    can be used only when it is known that [f] will be called exactly once. *)
 228 
val rec update
  [k, a, preserved : perm, consumed : perm]
  (cmp: (k, k) -> int,
    x: k, t: tree k a,
    f: (consumes a | preserved * consumes consumed) -> a
      | preserved * consumes consumed
  ) : () =
  (* Binary search for [x]. If [x] is found, its value is handed over to
     [f] and replaced in place with [f]'s result. If [x] is absent, [f] is
     never called. The tree's shape is never modified. *)
  match t with
  | Empty ->
      (* [x] is absent: nothing to do. *)
      ()
  | Node ->
      let c = cmp (x, t.key) in
      if c = 0 then
        t.value <- f t.value
      else if c < 0 then
        update [k, a, preserved, consumed] (cmp, x, t.left, f)
        (* WISH: get rid of the above type application *)
      else
        update [k, a, preserved, consumed] (cmp, x, t.right, f)
  end
 249 
 250 (* The following two functions (currently not exported) show that versions
 251    of [find] can be implemented in terms of [update]. *)
 252 
val find_and_copy [k, a] (
  copy: a -> a,
  cmp: (k, k) -> int,
  x: k,
  t: tree k a
) : option a =
  (* [r] receives a copy of the value bound to [x]. It stays [none] when
     [x] is absent, because [update] then never calls the function below. *)
  let r = newref none in
  (* The callback preserves [r @ ref (option a)] and consumes nothing
     ([empty]). It stores a copy of the value in [r] and gives the original
     back to the tree. Note that its parameter [x] (a value) shadows the
     key [x]. *)
  update [k, a, (r @ ref (option a)), empty] (cmp, x, t, fun (consumes x: a | r @ ref (option a)) : a =
    r := some (copy x);
    x
  );
  !r
 265 
 266 val find_variant [k, a] duplicable a => (
 267   cmp: (k, k) -> int,
 268   x: k,
 269   t: tree k a
 270 ) : option a =
 271   let id (x: a) : a = x in
 272   find_and_copy (id, cmp, x, t)
 273 
 274 val rec mem [k, a] (cmp: (k, k) -> int, x: k, t: tree k a) : bool =
 275   match t with
 276   | Empty ->
 277       False
 278   | Node ->
 279       let c = cmp (x, t.key) in
 280       if c = 0 then
 281         True
 282       else if c < 0 then
 283         mem (cmp, x, t.left)
 284       else
 285         mem (cmp, x, t.right)
 286  end
 287 
 288 (* -------------------------------------------------------------------------- *)
 289 
 290 (* Minimum and maximum elements. *)
 291 
 292 (* Because [min_binding] returns a binding but does not remove it from the
 293    tree, it is restricted to duplicable keys and values. *)
 294 
 295 (* [min_binding] is defined first for non-empty trees, then extended to empty
 296    trees. *)
 297 
 298 val rec min_binding
 299   [k, a] duplicable k => duplicable a =>
 300   (t : Node { left: tree k a; key: k; value: a; right: tree k a; height: int })
 301   : (k, a) =
 302   match t.left with
 303   | Empty ->
 304       t.key, t.value
 305   | Node ->
 306       min_binding t.left
 307   end
 308 
 309 val min_binding
 310   [k, a] duplicable k => duplicable a =>
 311   (t : tree k a)
 312   : option (k, a) =
 313   match t with
 314   | Empty ->
 315       none
 316   | Node ->
 317       some (min_binding t)
 318   end
 319 
 320 val rec max_binding
 321   [k, a] duplicable k => duplicable a =>
 322   (t : Node { left: tree k a; key: k; value: a; right: tree k a; height: int })
 323   : (k, a) =
 324   match t.right with
 325   | Empty ->
 326       t.key, t.value
 327   | Node ->
 328       max_binding t.right
 329   end
 330 
 331 val max_binding
 332   [k, a] duplicable k => duplicable a =>
 333   (t : tree k a)
 334   : option (k, a) =
 335   match t with
 336   | Empty ->
 337       none
 338   | Node ->
 339       some (max_binding t)
 340   end
 341 
 342 (* [extract_min_binding] extracts the node that contains the minimum key.
 343    It returns both this node (which can be re-used) and the remaining,
 344    re-organized tree. By convention, instead of returning a pair, we
 345    return a single node, which contains the minimum key, and whose
 346    right child is the remaining tree. *)
 347 
val rec extract_min_binding
  [k, a]
  (consumes t : Node { left: tree k a; key: k; value: a; right: tree k a; height: int })
  :             Node { left:    Empty; key: k; value: a; right: tree k a; height: int }
  =
  (* Only the [key], [value] and [right] fields of the returned node are
     meaningful; its [height] field is not updated here. *)
  match t.left with
  | Empty ->
      (* The desired node is [t], and the sub-tree [t.right] is what remains. *)
      t
  | Node ->
      (* Extract the minimum node out of the left sub-tree. *)
      let node = extract_min_binding t.left in
      (* Update in place the left sub-tree. *)
      t.left <- node.right;
      (* Perform a rotation at the root if required, and return. *)
      node.right <- bal t;
      node
  end
 366 
(* Mirror image of [extract_min_binding]: the returned node contains the
   maximum key, and its [left] child is the remaining, re-organized tree.
   Its [height] field is not updated here. *)
val rec extract_max_binding
  [k, a]
  (consumes t : Node { left: tree k a; key: k; value: a; right: tree k a; height: int })
  :             Node { left: tree k a; key: k; value: a; right:    Empty; height: int }
  =
  match t.right with
  | Empty ->
      (* The desired node is [t], and the sub-tree [t.left] is what remains. *)
      t
  | Node ->
      (* Extract the maximum node out of the right sub-tree. *)
      let node = extract_max_binding t.right in
      (* Update in place the right sub-tree. *)
      t.right <- node.left;
      (* Perform a rotation at the root if required, and return. *)
      node.left <- bal t;
      node
  end
 381 
 382 (* The private function [add_min_binding] takes a tree node whose only
 383    relevant fields are [key] and [value]. The [left] field is supposed
 384    to contain [Empty]. The [right] and [height] fields are irrelevant.
 385    This node is inserted into the tree [t], where it is expected to
 386    become the new minimum node. *)
 387 
val rec add_min_binding [k, a]
  (consumes node: Node { left: Empty; key: k; value: a; right: unknown; height: unknown },
   consumes t: tree k a) : tree k a =
  match t with
  | Empty ->
      (* Turn [node] into a singleton tree. *)
      node.right <- t; (* re-use the memory block at [t], which is [Empty] *)
      node.height <- 1;
      node
  | Node ->
      (* [node] carries the smallest key, so it belongs at the leftmost
         position: insert it into the left sub-tree, then rebalance. *)
      t.left <- add_min_binding (node, t.left);
      bal t
  end
 401 
(* Mirror image of [add_min_binding]: [node] (whose [right] field is
   [Empty]) is inserted into [t] as its new maximum node. *)
val rec add_max_binding [k, a]
  (consumes node: Node { left: unknown; key: k; value: a; right: Empty; height: unknown },
   consumes t: tree k a) : tree k a =
  match t with
  | Empty ->
      (* Turn [node] into a singleton tree. *)
      node.left <- t; (* re-use the memory block at [t], which is [Empty] *)
      node.height <- 1;
      node
  | Node ->
      (* [node] carries the largest key, so it belongs at the rightmost
         position: insert it into the right sub-tree, then rebalance. *)
      t.right <- add_max_binding (node, t.right);
      bal t
  end
 415 
 416 (* -------------------------------------------------------------------------- *)
 417 
 418 (* Removal. *)
 419 
 420 (* The private function [merge] combines two trees that have almost equal
 421    heights. *)
 422 
(* Every key of [t1] must precede every key of [t2], because [t1] becomes
   the left child of the minimum node of [t2]. [remove] satisfies this by
   passing the two children of a node. *)
val merge [k, a] (consumes t1: tree k a, consumes t2: tree k a) : tree k a =
  match t1, t2 with
  | Empty, t -> t
  | t, Empty -> t
  | Node, Node ->
      (* Extract the minimum node out of [t2]... *)
      let root = extract_min_binding t2 in
      (* And re-use this node to become the new root. Its [right] field
         already holds what remains of [t2]. *)
      root.left <- t1;
      bal root
  end
 434 
 435 (* Compared with OCaml's [remove], our [remove] function combines [find]
 436    and [remove]. The binding that was removed is returned via a reference,
 437    whereas the new tree is returned as a function result. *)
 438 
 439 (* [dst] is an out-parameter, just like in C, except here, we get a type
 440    error if we forget to write it! *)
 441 
 442 (* TEMPORARY we could avoid writes and calls to [bal] when nothing is
 443    removed *)
 444 
val rec remove [k, a] (
  cmp: (k, k) -> int,
  x: k,
  consumes t: tree k a,
  consumes dst: ref unknown
) : (tree k a | dst @ ref (option (k, a)))
  =
  match t with
  | Empty ->
      (* [x] is absent: report [none] and return the (empty) tree as is. *)
      dst := none;
      t
  | Node ->
      let c = cmp (x, t.key) in
      if c = 0 then begin
        (* Hand the removed binding to the caller, then splice the two
           children together; the node [t] itself is dropped. *)
        dst := some [(k, a)] (t.key, t.value);
        merge (t.left, t.right)
      end
      else if c < 0 then begin
        t.left <- remove (cmp, x, t.left, dst);
        bal t
      end
      else begin
        t.right <- remove (cmp, x, t.right, dst);
        bal t
      end
  end
 471 
 472 (* -------------------------------------------------------------------------- *)
 473 
 474 (* Iteration. *)
 475 
 476 (* Compared with OCaml's [iter], our [iter] function is generalized
 477    to allow early termination. The client function, [f], is allowed
 478    to return a Boolean flag, which indicates whether iteration should
 479    continue. The function [iter] itself returns a Boolean outcome
 480    which indicates whether iteration was performed all the way to the
 481    end. *)
 482 
 483 (* This feature implies that the type of the collection elements cannot
 484    be modified. Our version of [fold] (below) makes the converse choices:
 485    early termination is not possible, but the type of the elements can be
 486    changed from [a1] to [a2]. *)
 487 
(* In fact, our [iter] is exactly OCaml's [for_all], except for the
   evaluation order: our [iter], like OCaml's [iter], visits the keys in
   increasing order, whereas OCaml's [for_all] guarantees no particular
   order. Funny -- I never explicitly thought of [for_all] as a version
   of [iter] that has an early termination feature. *)
 493 
 494 val rec iter [k, a, p : perm] (
 495   f: (k,    a | p) -> bool,
 496   t: tree k a | p)  : bool =
 497   match t with
 498   | Empty ->
 499       True
 500   | Node ->
 501       iter (f, t.left) && f (t.key, t.value) && iter (f, t.right)
 502   end
 503 
 504 (* -------------------------------------------------------------------------- *)
 505 
 506 (* Map. *)
 507 
 508 (* Our [map] is modeled after OCaml's [mapi]. One could in fact offer
 509    even more generality by allowing keys to be copied/translated, as
 510    long as the key ordering is preserved. *)
 511 
val rec map [k, a1, a2, b, p : perm] duplicable k => (
  f: (k, consumes d: a1 | p) -> (       b | d @        a2),
  consumes t: tree k a1 | p)  : (tree k b | t @ tree k a2) =
  (* A fresh tree with the same shape is allocated. Keys are shared
     between the two trees (hence [duplicable k]). The original tree is
     kept, with its values now at type [a2]. *)
  match t with
  | Empty ->
      Empty
  | Node ->
      (* NOTE(review): the order in which [f] is applied to the bindings
         depends on the evaluation order of record fields; confirm it
         before relying on an ascending traversal. *)
      Node {
        left = map (f, t.left);
        key = t.key;
        value = f (t.key, t.value);
        right = map (f, t.right);
        height = t.height
      }
  end
 527 
 528 (* -------------------------------------------------------------------------- *)
 529 
 530 (* Fold. *)
 531 
 532 (* The two [fold] functions have the same type, but differ in the order
 533    in which the tree is visited. *)
 534 
 535 val rec fold_ascending [k, a1, a2, b, p : perm] (
 536   f: (k, consumes d:        a1, consumes accu: b | p) -> (b | d @        a2),
 537          consumes t: tree k a1, consumes accu: b | p)  : (b | t @ tree k a2) =
 538   match t with
 539   | Empty ->
 540       accu
 541   | Node ->
 542       let accu = fold_ascending (f, t.left, accu) in
 543       let accu = f (t.key, t.value, accu) in
 544       let accu = fold_ascending (f, t.right, accu) in
 545       accu
 546   end
 547 
 548 val rec fold_descending [k, a1, a2, b, p : perm] (
 549   f: (k, consumes d:        a1, consumes accu: b | p) -> (b | d @        a2),
 550          consumes t: tree k a1, consumes accu: b | p)  : (b | t @ tree k a2) =
 551   match t with
 552   | Empty ->
 553       accu
 554   | Node ->
 555       let accu = fold_descending (f, t.right, accu) in
 556       let accu = f (t.key, t.value, accu) in
 557       let accu = fold_descending (f, t.left, accu) in
 558       accu
 559   end
 560 
 561 (* -------------------------------------------------------------------------- *)
 562 
 563 (* The private function [join] has the same specification as [bal], except
 564    the left and right sub-trees may have arbitrary heights. *)
 565 
val rec join [k, a] (
  consumes t: Node { left: tree k a; key: k; value: a; right: tree k a; height: unknown }
) : tree k a =
  let left, right = t.left, t.right in
  match t.left, t.right with
  | Empty, _ ->
      (* No left sub-tree: [t]'s binding is the minimum of the result, so
         [t] itself (whose [left] is [Empty]) is inserted into [right]. *)
      add_min_binding (t, right)
  | _, Empty ->
      (* Symmetric case: [t] becomes the maximum node of [left]. *)
      add_max_binding (t, left)
  | Node, Node ->
      if left.height > right.height + 2 then begin
        (* The left node becomes the root. *)
        (* The root node becomes the right child. *)
        (* [t] is recursively joined with [left]'s right sub-tree, then
           [bal] repairs the balance at [left]. *)
        t.left <- left.right;
        left.right <- join t;
        bal left
      end
      else if right.height > left.height + 2 then begin
        (* Symmetric case: the right node becomes the root, and the root
           node becomes the left child. *)
        t.right <- right.left;
        right.left <- join t;
        bal right
      end
      else begin
        (* Heights are close enough: [t] can remain the root. *)
        update_height t;
        t
      end
  end
 593 
 594 (* -------------------------------------------------------------------------- *)
 595 
 596 (* The private function [concat] concatenates two trees of arbitrary heights.
 597    It is identical to [merge], except it calls [join] instead of [bal]. *)
 598 
(* Every key of [t1] must precede every key of [t2]. *)
val concat [k, a] (consumes t1: tree k a, consumes t2: tree k a) : tree k a =
  match t1, t2 with
  | Empty, t -> t
  | t, Empty -> t
  | Node, Node ->
      (* The minimum node of [t2] becomes the root; its [right] field
         already holds what remains of [t2]. *)
      let root = extract_min_binding t2 in
      root.left <- t1;
      join root
  end
 608 
 609 (* -------------------------------------------------------------------------- *)
 610 
 611 (* The private function [split] splits at a certain key. It returns a tree
 612    whose root node may or may not contain a value: note that the field
 613    [value] has type [option a] in the result type. This allows us to almost
 614    completely avoid memory allocation (and it is a natural thing to do anyway). *)
 615 
val rec split [k, a] (cmp: (k, k) -> int, x: k, consumes t: tree k a)
  : Node { left: tree k a; key: unknown; value: option a; right: tree k a; height: unknown } =
  (* In the result, [left] holds the keys smaller than [x], [right] holds
     the keys greater than [x], and [value] is [some] of the value bound to
     [x] if there is one, [none] otherwise. *)
  match t with
  | Empty ->
      (* Allocate a new node, containing no value, and whose sub-trees are empty. *)
      Node { left = Empty; key = (); value = none; right = t; height = () }
  | Node ->
      let c = cmp (x, t.key) in
      if c = 0 then begin
        (* We found the desired key. *)
        t.value <- some t.value; (* ah ha! look at this, feeble ML programmers *)
        t
      end
      else if c < 0 then begin
        (* [x] lies in the left sub-tree. Everything greater than [x] in it
           is re-joined with [t] and its right sub-tree. *)
        let root = split (cmp, x, t.left) in
        t.left <- root.right;
        root.right <- join t;
        root
      end
      else begin
        (* Symmetric case: [x] lies in the right sub-tree. *)
        let root = split (cmp, x, t.right) in
        t.right <- root.left;
        root.left <- join t;
        root
      end
  end
 642 
 643 (* -------------------------------------------------------------------------- *)
 644 
(* The private function [concat_or_join] accepts a tree whose root node may or
   may not contain a value, and turns it into a tree, using either [join] or
   [concat]. Thus, the left and right sub-trees are allowed to have arbitrary
   heights. *)
 649 
val concat_or_join [k, a] (consumes t: Node { left: tree k a; key: k; value: option a; right: tree k a; height: unknown }) : tree k a =
  match t.value with
  | Some { contents = d } ->
      (* Unwrap the value in place, so that [t] is a proper node again,
         and use it as the root. *)
      t.value <- d;
      join t
  | None ->
      (* No binding at the root: the node [t] is dropped. *)
      concat (t.left, t.right)
  end
 658 
 659 (* -------------------------------------------------------------------------- *)
 660 
 661 (* Merging. *)
 662 
 663 (* This function has the same specification as its OCaml counterpart. *)
 664 
val rec merge [k, a, b, c] (
  cmp: (k, k) -> int,
  f: (k, consumes option a, consumes option b) -> option c,
  consumes s1: tree k a,
  consumes s2: tree k b
) : tree k c =
  (* The taller tree's root key is used to split the other tree. The root
     node returned by [split] is re-used to hold the result of [f] at that
     key, and [concat_or_join] then drops it if [f] returned [none]. *)
  match s1, s2 with
  | Empty, Empty ->
      s1
  | _, _ ->
      if height s1 >= height s2 then
        match s1 with
        | Node ->
            let root2 = split (cmp, s1.key, s2) in
            root2.left <- merge (cmp, f, s1.left, root2.left);
            root2.key <- s1.key;
            root2.value <- f (s1.key, some s1.value, root2.value);
            root2.right <- merge (cmp, f, s1.right, root2.right);
            concat_or_join root2
        | Empty ->
            fail (* impossible: not both empty, and height s1 >= height s2 *)
        end
      else
        match s2 with
        | Node ->
            let root1 = split (cmp, s2.key, s1) in
            root1.left <- merge (cmp, f, root1.left, s2.left);
            root1.key <- s2.key;
            root1.value <- f (s2.key, root1.value, some s2.value);
            root1.right <- merge (cmp, f, root1.right, s2.right);
            concat_or_join root1
        | Empty ->
            fail (* impossible: height s2 > height s1 >= 0 *)
        end
  end
 700 
 701 (* -------------------------------------------------------------------------- *)
 702 
 703 (* Compared to OCaml's [filter], our [filter] is more general. Instead of
 704    returning a Boolean value, the function [p] returns an option. This
 705    allows us not only to drop certain entries, but also to modify the
 706    existing entries, and possibly to change their type. *)
 707 
val rec filter [k, a, b] (
  p: (k, consumes a) -> option b,
  consumes t: tree k a
) : tree k b =
  (* The tree is transformed in place. [p] is applied in ascending key
     order: left sub-tree, this node, right sub-tree. *)
  match t with
  | Empty ->
      t
  | Node ->
      t.left <- filter (p, t.left);
      t.value <- p (t.key, t.value);
      t.right <- filter (p, t.right);
      (* Keep [t] as the root if [p] returned [some], drop it otherwise. *)
      concat_or_join t
  end
 721 
 722 (* The function [partition] is generalized in a similar manner. Instead
 723    of returning a Boolean value, it returns a choice. *)
 724 
 725 open choice (* TEMPORARY *)
 726 
val rec partition [k, a, b, c] (
  p: (k, consumes a) -> choice b c,
  consumes t: tree k a
) : (tree k b, tree k c) =
  (* Bindings for which [p] returns [Left] go to the first tree, the others
     to the second. The node [t] is re-used in whichever tree receives its
     binding. *)
  match t with
  | Empty ->
      t, Empty
  | Node ->
      (* [p] is called in ascending key order: left, root, right. *)
      let ll, lr = partition (p, t.left) in
      let choice = p (t.key, t.value) in
      let rl, rr = partition (p, t.right) in
      match choice with
      | Left ->
          t.left <- ll;
          t.value <- choice.contents;
          t.right <- rl;
          join t, concat (lr, rr)
      | Right ->
          t.left <- lr;
          t.value <- choice.contents;
          t.right <- rr;
          concat (ll, rl), join t
      end
      (* TEMPORARY why do I get a warning about this merge? isn't the expected type
         propagated all the way down? *)
  end
 753 
 754 (* -------------------------------------------------------------------------- *)
 755 
 756 (* Iterators. *)
 757 
 758 (* OCaml's [Map] library uses iterators in order to implement the comparison
 759    of two trees. It implements an iterator as a list of trees. Unfortunately,
 760    as of now, ordinary tree iterators are difficult to express in Mezzo,
 761    because it is hard to explain how/why the ownership of the tree cells (and
 762    keys, and values) is returned from the iterator to the tree once iteration
 763    is complete. *)
 764 
 765 (* Mutable zippers allow building a form of tree iterators that have the key
 766    feature that the memory footprint of the iterator is exactly the memory
 767    footprint of the original tree. Hence, keeping track of ownership is simpler
 768    in this approach. *)
 769 
mutable data zipper k a =
  (* We are at the root: there is no enclosing node. *)
  | ZEmpty
  (* We are inside the left child of this node. The node's [left] field is
     replaced with [father], the context above this node ([advance] turns
     a [Node] into a [ZLeft] by changing its tag in place). *)
  | ZLeft  { father: zipper k a; key: k; value: a; right: tree k a;    height: int }
  (* Symmetric: we are inside the right child of this node, and [father]
     replaces the node's [right] field. *)
  | ZRight { left: tree k a;     key: k; value: a; father: zipper k a; height: int }
 774 
 775 (* In this approach, an iterator is a pair of a zipper and a tree, together
 776    with a tag that indicates whether we are arriving at this node, currently
 777    paused at this node, or leaving this node. Furthermore, an iterator can
 778    be paused only at a binary node -- never at an empty node. *)
 779 
mutable data iterator k a =
  (* About to enter [focus], which may be [Empty]. *)
  | IArriving { context: zipper k a; focus: tree k a }
  (* Paused at the binary node [focus]; its binding is the current element. *)
  | IAt       { context: zipper k a; focus: Node { left: tree k a; key: k; value: a; right: tree k a; height: int }}
  (* Done with [focus] (the whole sub-tree has been visited); about to move up. *)
  | ILeaving  { context: zipper k a; focus: tree k a }
 784 
(* The function [advance] advances an iterator until either it reaches a new
   element (in which case the iterator is left in state [IAt]) or there are no
   more elements (in which case the iterator is left in state [ILeaving], with
   the empty context [ZEmpty] and with the entire, reconstructed tree as its
   focus). Initially, the iterator can be in any state; if it is initially in
   state [IAt], then it will advance to the next element. *)
 790 
val rec advance [k, a] (i: iterator k a) : () =
  (* No memory is allocated: moving down turns a tree node into a zipper
     frame, and moving up turns a frame back into a tree node, in both
     cases by changing the block's tag in place. *)
  match i with
  | IArriving { context = z; focus = f } ->
      match f with
      | Empty ->
          (* Skip empty nodes. *)
          tag of i <- ILeaving;
          advance i
      | Node { left } ->
          (* When arriving at a node, descend immediately into the left child. *)
          tag of f <- ZLeft;
          f.father <- z;
          (* [f] is now a zipper! *)
          i.context <- f;
          i.focus <- left;
          advance i
        end
  | IAt { context = z; focus = f } ->
      let right = f.right in
      (* After handling a node, descend into its right child. *)
      tag of f <- ZRight;
      f.father <- z;
      (* [f] is now a zipper! *)
      tag of i <- IArriving;
      i.context <- f;
      i.focus <- right;
      advance i
  | ILeaving { context = z; focus = f } ->
      match z with
      | ZEmpty ->
          (* We are finished. [i] stays in state [ILeaving], and its focus
             [f] is the whole tree. *)
          ()
      | ZLeft { father } ->
          tag of z <- Node;
          z.left <- f;
          (* [z] is now a tree! *)
          (* After exiting a left child, pause at its father. *)
          tag of i <- IAt;
          i.context <- father;
          i.focus <- z
      | ZRight { father } ->
          tag of z <- Node;
          z.right <- f;
          (* [z] is now a tree! *)
          (* After exiting a right child, continue in ascending mode. *)
          i.context <- father;
          i.focus <- z;
          advance i
      end
  end
 841 
(* Our iterators are unsatisfactory in that [advance] leaves the client to
   inspect the internal representation of the iterator (its tag and focus),
   instead of returning a nicely packaged pair of an element and a new
   iterator. For the time being, this is ok, because we are using the
   iterators only internally. TEMPORARY *)
 846 
 847 (* [iterate] turns a tree into a fresh iterator. *)
 848 
 849 val iterate [k, a] (consumes t: tree k a) : iterator k a =
 850   let i = IArriving { context = ZEmpty; focus = t } in
 851   advance i;
 852   i
 853 
 854 (* [stop] turns an iterator back into a complete tree. *)
 855 
val rec stop [k, a] (consumes i: iterator k a) : tree k a =
  match i with
  | IAt { context = z; focus = f } ->
      (* Walk up the context, re-attaching the focus to each enclosing
         frame (turned back into a [Node]) until the root is reached. *)
      match z with
      | ZEmpty ->
          f
      | ZLeft { father } ->
          tag of z <- Node;
          z.left <- f;
          (* [z] is now a tree! *)
          i.context <- father;
          i.focus <- z;
          stop i
      | ZRight { father } ->
          tag of z <- Node;
          z.right <- f;
          (* [z] is now a tree! *)
          i.context <- father;
          i.focus <- z;
          stop i
      end
  | ILeaving { context = ZEmpty; focus = f } ->
      (* Iteration is complete: the focus is already the whole tree. *)
      f
  | _ ->
      fail (* impossible, if iterator has been properly [advance]d *)
  end
 882 
 883 (* [recover] stops the iterator [i] and dynamically checks that the resulting
 884    tree is equal to the argument [t]. This dynamic check is required for the
 885    moment. Assigning more precise types to our iterators might allow us to
 886    avoid it. TEMPORARY *)
 887 
val recover [k, a] (consumes i: iterator k a, t: unknown) : (| t @ tree k a) =
  (* Rebuild the complete tree from the iterator. *)
  let u = stop i in
  (* If [u] is physically [t], the permission [u @ tree k a] also describes
     [t], which is what the signature promises. *)
  if t == u then
    ()
  else
    (* The iterator was not built from [t]: abort. *)
    fail
 894 
 895 (* -------------------------------------------------------------------------- *)
 896 
 897 (* We use an iterator to implement the comparison of two trees. *)
 898 
val compare [k, a] (
  cmpk: (k, k) -> int,
  cmpa: (a, a) -> int,
  m1: tree k a,
  m2: tree k a
) : int =

  (* Create an iterator for each of the trees. [iterate] consumes the
     permissions [m1 @ tree k a] and [m2 @ tree k a]; the calls to
     [recover] below give them back. *)
  let i1 = iterate m1
  and i2 = iterate m2 in

  (* Loop. Compare the current bindings lexicographically (key first, then
     value) and return the first non-zero result. If they are equal, advance
     both iterators and continue. Inside the loop, [m1] and [m2] name the
     current nodes, shadowing the parameters. *)
  let rec loop (| i1 @ iterator k a * i2 @ iterator k a) : int =
    match i1, i2 with
    | IAt { focus = m1 }, IAt { focus = m2 } ->
        let c = cmpk (m1.key, m2.key) in
        if c <> 0 then c else begin
          let c = cmpa (m1.value, m2.value) in
          if c <> 0 then c else begin
            advance i1;
            advance i2;
            loop()
          end
        end
    | IAt, _ ->
        (* The sequence [i1] is longer. *)
        1
    | _, IAt ->
        (* The sequence [i2] is longer. *)
        -1
    | _, _ ->
        (* The comparison succeeded, all the way. *)
        0
    end

  in
  let c : int = loop() in
  (* Stop the iterators and recover the permissions for the trees. *)
  recover (i1, m1);
  recover (i2, m2);
  c
  (* BUG well, not a bug, but a feature wish: if I omit the calls to
     recover above, I get a good error message, but with a completely
     useless location (the entire function body). Since we are
     propagating expected types down, couldn't we signal the error
     message at the sub-expression "c" (final line), instead of at
     the level of the entire function? *)
 946 
 947 (* -------------------------------------------------------------------------- *)
 948 
 949 (* As a test of the type-checker, we re-implement [iter] using iterators. *)
 950 
 951 val iter_variant [k, a, p : perm] (
 952   f: (k,           a | p) -> bool,
 953          t: tree k a | p)  : bool =
 954 
 955   (* Create an iterator. *)
 956   let i = iterate t in
 957   (* Loop. *)
 958   let rec loop (| i @ iterator k a * p) : bool =
 959     match i with
 960     | IAt { focus = m } ->
 961         f (m.key, m.value) && loop()
 962     | _ ->
 963         True
 964     end
 965   in
 966   (* Stop the iterator and recover the permission for the tree. *)
 967   let outcome : bool = loop() in
 968   recover (i, t);
 969   outcome
 970 
 971 (* -------------------------------------------------------------------------- *)
 972 
 973 (* Iterators in another style. *)
 974 
(* A bandit is a binary tree node that does not own its left child. *)
 976 
 977 (* The following type describes a list of bandits. *)
 978 
 979 (* TEMPORARY because we don't have type abbreviations, we redefine this type
 980    instead of using an abbreviation for a list of bandits. *)
 981 
data bandits k a =
  | End
  | More {
      (* A bandit owns its key, value and right subtree, but not its left
         child (hence [left: dynamic]). *)
      head: Node { left: dynamic; key: k; value: a; right: tree k a; height: int };
      tail: bandits k a
    }
 988 
 989 (* An enumeration is basically a reference to a list of bandits. Furthermore,
 990    an enumeration object is able to adopt the tree nodes that have already
 991    been visited. The nodes that are adopted are isolated, i.e. they do not
 992    own their children; this forces us (unfortunately) to redefine a new type
 993    of tree nodes that do not own their children. *)
 994 
 995 (* TEMPORARY perhaps we could plan ahead and parameterize the type of tree
 996    nodes with the type of their children before tying the recursive knot?
 997    but that would require a recursive type abbreviation, I am afraid *)
 998 
mutable data enumeration k a =
  (* [bandits] lists the not-yet-visited bandits, next one first. Nodes that
     have already been visited are adopted by the enumeration at type
     [visited k a]. *)
  Enum { bandits: bandits k a } adopts visited k a
1001 
mutable data visited k a =
  (* Same shape as [tree k a], but an adopted node does not own its children,
     hence [left: dynamic] and [right: dynamic]. *)
  | VisitedEmpty
  | VisitedNode { left: dynamic; key: k; value: a; right: dynamic; height: int }
1005 
(* This smart constructor [cons]es a tree in front of a list of bandits. The
   left spine of the tree is walked down all the way, so that the nodes that
   are effectively inserted into the list are bandits. *)
1009 
val rec cons_bandits [k, a] (consumes t: tree k a, e: enumeration k a) : () =
  match t with
  | Empty ->
      (* Nothing to push. The empty tree is retagged and adopted by [e], so
         that [reconstruct] later finds it through [e owns t]. *)
      tag of t <- VisitedEmpty; (* this is a no-op *)
      give t to e
  | Node ->
      (* Push [t] onto the list, then continue down its left spine. The node
         pushed last, at the bottom of the spine, ends up at the head. *)
      e.bandits <- More { head = t; tail = e.bandits };
      cons_bandits (t.left, e)
  end
1019 
val new_enum [k, a] (consumes t: tree k a) : enumeration k a =
  (* Start with no bandits, then push the left spine of [t]. The head of the
     list is then the bottom of that spine. *)
  let e = Enum { bandits = End } in
  cons_bandits (t, e);
  e
1024 
val consume_enum [k, a] (e: enumeration k a) : () =
  (* Visit the bandit at the head of [e.bandits]. Pop it, push the left spine
     of its right subtree (whose nodes come next in the enumeration), then
     retag the node and let [e] adopt it. Fails if no bandit is left. *)
  match e.bandits with
  | More { head; tail } ->
      e.bandits <- tail;
      cons_bandits (head.right, e);
      tag of head <- VisitedNode; (* this is a no-op *)
      give head to e
  | End ->
      fail
      (* We could eliminate this case if we declared that initially [e] has type
         Enum { bandits: More { ... } }. But this type would be hugely verbose,
         plus we would then have to add a [consumes] annotation and write the
         post-condition [e @ enumeration k a]. Abandoned. *)
  end
1039 
1040 (* The following function reconstructs the ownership of the original tree after
1041    the enumeration has stopped. We are in a situation where the tree nodes that
1042    have already been visited have been adopted by the enumeration object, whereas
1043    the tree nodes that have not yet been visited are listed, *in order*, as part
1044    of the list [e.bandits]. A dynamic ownership test is used to distinguish these
1045    two situations. Yes, this is quite crazy, and I am not even sure that it works. *)
1046 
val rec reconstruct [k, a] (t: dynamic, e: enumeration k a) : (| t @ tree k a) =
  (* Case 1: [t] has been visited (or is an adopted empty tree), so [e] owns
     it. Take it back, retag it as a tree node, and rebuild its children. *)
  if e owns t then begin
    take t from e;
    match t with
    | VisitedEmpty ->
        tag of t <- Empty (* this is a no-op *)
    | VisitedNode ->
        tag of t <- Node; (* this is a no-op *)
        reconstruct (t.left, e);
        reconstruct (t.right, e)
    end
  end
  else
    (* Case 2: [t] has not been visited. This code expects it to be the head
       of [e.bandits]; anything else aborts.
       NOTE(review): the TEMPORARY concern below is real. [cons_bandits]
       pushes the root first and the bottom of the left spine last, so right
       after [new_enum] the head of [e.bandits] is the leftmost node. Calling
       [reconstruct] on the root at that point reaches this branch with
       [head != t] and fails. An unvisited node's unvisited left descendants
       precede it in the list, yet this traversal reaches them only after
       the node itself. *)
    match e.bandits with
    | End ->
        fail (* impossible *)
    | More { head; tail } ->
        if head == t then begin
          (* At this point, [t] is a bandit, a tree node that does not own
             its left child. We need to reconstruct this left child in order
             to obtain [t @ tree k a]. *)
          e.bandits <- tail;
          reconstruct (t.left, e)
        end
        else
          fail (* impossible *)
          (* TEMPORARY I need to think more about the order of the non-visited
             trees in e.bandits; I am afraid they will *not* appear in the
             desired order, so the test [head == t] *WILL* FAIL! *)
    end
1077 
1078 (* We use an enumeration to implement the comparison of two trees. *)
1079 
val compare_variant [k, a] (
  cmpk: (k, k) -> int,
  cmpa: (a, a) -> int,
  consumes m1: tree k a,
  consumes m2: tree k a
) : int =

  (* Create an enumeration for each of the trees. Both trees are consumed:
     this variant does not reconstruct them afterwards. *)
  let e1 = new_enum m1
  and e2 = new_enum m2 in

  (* Loop. The head of each list of bandits holds the next binding of that
     enumeration. Inside the loop, [m1] and [m2] name these head nodes,
     shadowing the parameters. *)
  let rec loop (| e1 @ enumeration k a * e2 @ enumeration k a) : int =
    match e1.bandits, e2.bandits with
    | More { head = m1 }, More { head = m2 } ->
        let c = cmpk (m1.key, m2.key) in
        if c <> 0 then c else begin
          let c = cmpa (m1.value, m2.value) in
          if c <> 0 then c else begin
            consume_enum e1;
            consume_enum e2;
            loop()
          end
        end
    | More, End ->
        (* The sequence [e1] is longer. *)
        1
    | End, More ->
        (* The sequence [e2] is longer. *)
        -1
    | End, End ->
        (* The comparison succeeded, all the way. *)
        0
    end

  in
  loop()
1117 
1118 (* -------------------------------------------------------------------------- *)
1119 
1120 (* Conversion of a tree to a list. *)
1121 
val bindings [k, a] duplicable k => duplicable a => (t: tree k a) : list::list (k, a) =
  (* Fold in descending order and cons each binding onto the accumulator, so
     the list comes out in ascending order (assuming [fold_descending] visits
     keys from largest to smallest). Keys and values are copied into the
     list, hence the duplicability requirements. *)
  let push (key: k, value: a, acc: list::list (k, a)) : list::list (k, a) =
    list::cons [(k, a)] ((key, value), acc)
  in
  fold_descending (push, t, list::nil)
1127 
1128 (* -------------------------------------------------------------------------- *)
1129 (* -------------------------------------------------------------------------- *)
1130 
1131 (* We now wrap the type [tree] in another type, [treeMap], which the client
1132    will work with. There are two reasons for doing so. One is that this allows
1133    some functions, such as [add], to return unit instead of returning a new
1134    data structure. The other is that this allows us to store the comparison
1135    function. *)
1136 
mutable data treeMap k (c : term) a =
  (* [tree] holds the bindings. [cmp] is the key comparison function. Its
     singleton type [=c] records its identity in the type parameter [c], so
     two maps of type [treeMap k c a] are known to use the same function. *)
  TreeMap { tree: tree k a; cmp: =c | c @ (k, k) -> int }
1139 
(* [cardinal m] is the number of bindings in [m]. It counts the nodes of the
   underlying tree. *)
val cardinal [k, c : term, a] (m: treeMap k c a) : int =
  cardinal m.tree
1142 
1143 (* -------------------------------------------------------------------------- *)
1144 
1145 (* Creating an empty map requires supplying a comparison function [cmp],
1146    which is stored within the new data structure. *)
1147 
(* [create cmp] wraps a tree produced by the tree-level [create] together
   with [cmp]. The map's type records [cmp] as its comparison function. *)
val create [k, a] (cmp: (k, k) -> int) : treeMap k cmp a =
  TreeMap { tree = create(); cmp }
1150 
(* [singleton (cmp, x, d)] returns a map holding the single binding of [x] to
   [d], ordered by [cmp]. Both [x] and [d] are consumed. *)
val singleton [k, a] (cmp: (k, k) -> int, consumes x: k, consumes d: a): treeMap k cmp a =
  TreeMap { tree = singleton (x, d); cmp }
1153 
(* [is_empty m] tests whether the underlying tree of [m] is empty. *)
val is_empty [k, c: term, a] (m : treeMap k c a) : bool =
  is_empty m.tree
1156 
(* [add (x, d, m)] adds a binding of [x] to [d] in [m], in place: the tree
   returned by the tree-level [add] replaces [m.tree]. The map's own
   comparison function is used, and [x] and [d] are consumed. *)
val add [k, c: term, a] (
  consumes x: k,
  consumes d: a,
  m: treeMap k c a
) : () =
  m.tree <- add (m.cmp, x, d, m.tree)
1163 
(* [find (x, m)] looks up [x] in [m] and returns an [option a]. [a] must be
   duplicable because the value is returned to the caller while [m] keeps
   it. *)
val find [k, c: term, a] duplicable a => (
  x: k,
  m: treeMap k c a
) : option a =
  find (m.cmp, x, m.tree)
1169 
(* [update (x, m, f)] delegates to the tree-level [update] with the map's
   comparison function, so that [f] may update the value bound to [x] in
   place. [f] may use the permission [preserved], which is kept, and
   [consumed], which is consumed. The type arguments are passed
   explicitly. *)
val update [k, c: term, a, preserved : perm, consumed : perm] (
  x: k,
  m: treeMap k c a,
  f: (consumes a | preserved * consumes consumed) -> a
| preserved * consumes consumed
) : () =
  update [k, a, preserved, consumed] (m.cmp, x, m.tree, f)
1177 
(* [mem (x, m)] tests whether [x] is bound in [m], using the map's
   comparison function. *)
val mem [k, c: term, a] (x: k, m: treeMap k c a) : bool =
  mem (m.cmp, x, m.tree)
1180 
(* [min_binding m] returns (a copy of) the minimum binding of [m], if any.
   Keys and values must be duplicable because [m] keeps them. *)
val min_binding [k, c: term, a] duplicable k => duplicable a => (m : treeMap k c a) : option (k, a) =
  min_binding m.tree
1183 
(* [max_binding m] returns (a copy of) the maximum binding of [m], if any.
   Keys and values must be duplicable because [m] keeps them. *)
val max_binding [k, c: term, a] duplicable k => duplicable a => (m : treeMap k c a) : option (k, a) =
  max_binding m.tree
1186 
(* [extract_min_binding m] removes a binding from [m], in place, and returns
   it, or returns [none] if [m] is empty. The tree-level
   [extract_min_binding] returns a node whose [key] and [value] are handed to
   the caller and whose [right] field becomes the new tree (presumably the
   original tree minus its minimum binding). No duplicability is required:
   the binding leaves the map. *)
val extract_min_binding [k, c: term, a] (m: treeMap k c a) : option (k, a) =
  match m.tree with
  | Empty ->
      none
  | Node ->
      let node = extract_min_binding m.tree in
      m.tree <- node.right;
      some [(k, a)] (node.key, node.value)
  end
1196 
(* [extract_max_binding m] removes a binding from [m], in place, and returns
   it, or returns [none] if [m] is empty. The tree-level
   [extract_max_binding] returns a node whose [key] and [value] are handed to
   the caller and whose [left] field becomes the new tree (presumably the
   original tree minus its maximum binding). *)
val extract_max_binding [k, c: term, a] (m: treeMap k c a) : option (k, a) =
  match m.tree with
  | Empty ->
      none
  | Node ->
      let node = extract_max_binding m.tree in
      m.tree <- node.left;
      some [(k, a)] (node.key, node.value)
  end
1206 
(* [remove (x, m)] removes [x] from [m], in place. [dst] is a fresh
   reference that the tree-level [remove] fills with the outcome, an
   [option (k, a)], which is returned. *)
val remove [k, c: term, a] (x: k, m: treeMap k c a) : option (k, a) =
  let dst = newref () in
  m.tree <- remove (m.cmp, x, m.tree, dst);
  !dst
1211 
(* [iter (m, f)] applies [f] to the bindings of [m], threading the
   permission [p]. The boolean result comes from the tree-level [iter]
   (presumably [False] as soon as [f] returns [False], [True] otherwise). *)
val iter [k, c: term, a, p : perm] (
  m: treeMap k c a,
  f: (k, a | p) -> bool
  | p
) : bool =
  iter (f, m.tree)
1218 
(* [for_all] is an alias for [iter] and shares its boolean result. *)
val for_all =
  iter
1221 
1222 (* [exists] could be implemented directly, but an implementation in
1223    terms of [for_all] is preferred, as a test of the type-checker. *)
1224 
val exists [k, c: term, a, p : perm] (
  m: treeMap k c a,
  f: (k, a | p) -> bool
 | p
) : bool =
  (* Some binding satisfies [f] exactly when it is not the case that every
     binding satisfies the negation of [f]. *)
  let negated_f (key: k, value: a | p) : bool =
    not (f (key, value))
  in
  not (for_all (m, negated_f))
1233 
(* [map (m, f)] builds a new map with the same comparison function, whose
   values are produced by [f]. [f] may also change the type of each original
   value from [a1] to [a2], so [m] is returned at type [treeMap k c a2]. Keys
   end up in both maps, hence [duplicable k]. *)
val map [k, c: term, a1, a2, b, p : perm] duplicable k => (
  consumes m: treeMap k c a1,
  f: (k, consumes d: a1 | p) -> (b | d @ a2)
  | p
) : (treeMap k c b | m @ treeMap k c a2) =
  TreeMap { tree = map (f, m.tree); cmp = m.cmp }
1240 
1241 (* [copy] could be defined directly, but is here defined as a special case of [map]. *)
1242 
val copy [k, c: term, a, b] duplicable k => (m: treeMap k c a, f: a -> b) : treeMap k c b =
  (* Ignore the key and transform the value with [f]. The value is not
     consumed, so [m] keeps its type [treeMap k c a]. *)
  let apply_to_value (key: k, v: a) : b = f v in
  map (m, apply_to_value)
1245 
(* [fold_ascending (m, accu, f)] folds [f] over the bindings of [m], starting
   from [accu]. It delegates to the tree-level [fold_ascending], whose
   argument order is [(f, tree, accu)]. [f] may change the type of each
   value from [a1] to [a2]. *)
val fold_ascending [k, c: term, a1, a2, b, p : perm] (
  consumes m: treeMap k c a1,
  consumes accu: b,
  f: (k, consumes d: a1, consumes accu: b | p) -> (b | d @ a2)
  | p
) : (b | m @ treeMap k c a2) =
  fold_ascending (f, m.tree, accu)
1253 
(* [fold_descending (m, accu, f)] folds [f] over the bindings of [m],
   starting from [accu]. It delegates to the tree-level [fold_descending],
   whose argument order is [(f, tree, accu)]. [f] may change the type of
   each value from [a1] to [a2]. *)
val fold_descending [k, c: term, a1, a2, b, p : perm] (
  consumes m: treeMap k c a1,
  consumes accu: b,
  f: (k, consumes d: a1, consumes accu: b | p) -> (b | d @ a2)
  | p
) : (b | m @ treeMap k c a2) =
  fold_descending (f, m.tree, accu)
1261 
(* [fold] is an alias for [fold_ascending]. *)
val fold =
  fold_ascending
1264 
(* [merge (m1, m2, f)] combines two maps that share the comparison function
   [cmp]. [f] decides each resulting binding from the optional values found
   in [m1] and [m2]. The record [m1] is reused in place to hold the result,
   and both arguments are consumed. *)
val merge [k, cmp: term, a, b, c] (
  consumes m1: treeMap k cmp a,
  consumes m2: treeMap k cmp b,
  f: (k, consumes option a, consumes option b) -> option c
) : treeMap k cmp c =
  m1.tree <- merge (m1.cmp, f, m1.tree, m2.tree);
  m1
1272 
(* [split (x, m)] splits [m] around [x]. The tree-level [split] returns a
   node [root] whose [left] and [right] fields hold the two parts and whose
   [value] is an [option a] (presumably the value bound to [x], if any). The
   record [m] is reused for the left part; a new map is allocated for the
   right part. *)
val split [k, c: term, a] (
  x: k,
  consumes m: treeMap k c a
) : (treeMap k c a, option a, treeMap k c a) =
  let root = split (m.cmp, x, m.tree) in
  m.tree <- root.left;
  m, root.value, TreeMap { tree = root.right; cmp = m.cmp }
1280 
(* [filter (m, p)] applies [p] to each binding. Each [some] result becomes a
   binding of the result map (presumably [none] drops the binding). The
   record [m] is reused in place, and its type changes to
   [treeMap k c b]. *)
val filter [k, c: term, a, b] (
  consumes m: treeMap k c a,
  p: (k, consumes a) -> option b
) : treeMap k c b =
  m.tree <- filter (p, m.tree);
  m
1287 
(* [partition (m, p)] splits [m] into two maps according to the [choice b c]
   returned by [p] for each binding. The record [m] is reused for the first
   map; a new map, sharing the same comparison function, is allocated for
   the second. *)
val partition [k, cmp: term, a, b, c] (
  consumes m: treeMap k cmp a,
  p: (k, consumes a) -> choice b c
) : (treeMap k cmp b, treeMap k cmp c) =
  let left, right = partition (p, m.tree) in
  m.tree <- left;
  m, TreeMap { tree = right; cmp = m.cmp }
1295 
(* We might wish to make [compare] a unary function of [cmp] that returns
   a binary function of [(m1, m2)]. *)
1298 
(* [compare (cmp, m1, m2)] compares two maps: keys with the maps' comparison
   function, values with [cmp]. Both maps have type [treeMap k c a], so
   [m1.cmp] is also [m2]'s comparison function. *)
val compare [k, c: term, a] (
  cmp: (a, a) -> int,
  m1: treeMap k c a,
  m2: treeMap k c a
) : int =
  compare (m1.cmp, cmp, m1.tree, m2.tree)
1305 
val equal [k, c: term, a] (
  cmp: (a, a) -> int,
  m1: treeMap k c a,
  m2: treeMap k c a
) : bool =
  (* Two maps are equal exactly when [compare] reports no difference. *)
  let outcome = compare (cmp, m1, m2) in
  outcome = 0
1312 
(* [bindings m] returns the bindings of [m] as a list, via the tree-level
   [bindings]. Keys and values must be duplicable because [m] keeps them. *)
val bindings [k, c: term, a] duplicable k => duplicable a => (
  m: treeMap k c a
) : list::list (k, a) =
  bindings m.tree
1317 
1318 (* TEMPORARY compare .mzi file with map.mli *)
1319 
1320 (*
1321 Local Variables:
1322 compile-command: "../mezzo mutableTreeMap.mz"
1323 End:
1324 *)