RangeTree.java


Below is the syntax highlighted version of RangeTree.java from §9.2 Geometric Search.


/******************************************************************************
 *  Compilation:  javac RangeTree.java
 *  Execution:    java RangeTree
 *  
 *  Range tree implemented using binary search trees.
 *
 *  Assumes no two points have the same x or y coordinate!!!!
 *
 *  Could be made more efficient by assuming points are given
 *  all at once. Then could sort by x and by y to ensure tree
 *  are perfectly balanced.
 * 
 ******************************************************************************/

public class RangeTree<Key extends Comparable<Key>>  {

    private Node root;   // root of the primary BST

    // BST helper node data type
    private class Node {
        Key x, y;                   // x- and y- coordinates
        Node left, right;           // left and right subtrees
        RangeSearch<Key, Key> bst;  // secondary BST

        Node(Key x, Key y) {
            this.x   = x;
            this.y   = y;
            this.bst = new RangeSearch<Key, Key>();
            this.bst.put(y, x);
        }
    }


  /***********************************************************************
    *  Insert (x, y) into range tree.
    *  Sort by x in primary BST; sort by y in secondary BST.
    ***************************************************************************/
    public void insert(Key x, Key y) {
        root = insert(root, x, y);
    }

    private Node insert(Node h, Key x, Key y) {
        if (h == null) return new Node(x, y);
        h.bst.put(y, x);
        if (less(x, h.x)) h.left  = insert(h.left,  x, y);
        else              h.right = insert(h.right, x, y);
        return h;
    }


   /***************************************************************************
    *  Range searching.
    ***************************************************************************/


    // print all points in range
    public void query2D(Interval2D<Key> rect) {
        Interval<Key> intervalX = rect.intervalX;
        // find splitting node h where h.x is in the x-interval
        Node h = root;
        while (h != null && !intervalX.contains(h.x)) {
            if      (less(intervalX.max(), h.x)) h = h.left;
            else if (less(h.x, intervalX.min())) h = h.right;
        }
        if (h == null) return;

        if (rect.contains(h.x, h.y)) StdOut.println("A: " + h.x + ", " + h.y);

        queryL(h.left,  rect);
        queryR(h.right, rect);
    }

    // find all keys >= xmin in subtree rooted at h
    private void queryL(Node h, Interval2D<Key> rect) {
        if (h == null) return;
        if (rect.contains(h.x, h.y)) StdOut.println("B: " + h.x + ", " + h.y);
        if (!less(h.x, rect.intervalX.min())) {
            enumerate(h.right, rect);
            queryL(h.left, rect);
        }
        else {
            queryL(h.right, rect);
        }
    }

    // find all keys <= xmax in subtree rooted at h
    private void queryR(Node h, Interval2D<Key> rect) {
        if (h == null) return;
        if (rect.contains(h.x, h.y)) StdOut.println("C: " + h.x + ", " + h.y);
        if (!less(rect.intervalX.max(), h.x)) {
            enumerate(h.left, rect);
            queryR(h.right, rect);
        }
        else {
            queryR(h.left, rect);
        }
    }

    // precondition: subtree rooted at h has keys between xmin and xmax
    private void enumerate(Node h, Interval2D<Key> rect) {
        if (h == null) return;
        StdOut.println("integrity: " + h.bst.check());
        Iterable<Key> list = h.bst.range(rect.intervalY);
        for (Key y : list) {
            Key x = h.bst.get(y);
            StdOut.println("D: " + x + ", " + y);
        }
        StdOut.println("-");
    }


   /***************************************************************************
    *  Helper comparison functions.
    ***************************************************************************/

    private boolean less(Key k1, Key k2) {
        return k1.compareTo(k2) < 0;
    }


   /***************************************************************************
    *  Test client.
    ***************************************************************************/
    public static void main(String[] args) {
        int M = Integer.parseInt(args[0]);   // queries
        int N = Integer.parseInt(args[1]);   // points

        RangeTree<Integer> st = new RangeTree<Integer>();

        // insert N random points in the unit square
        for (int i = 0; i < N; i++) {
            int x = StdRandom.uniform(100);
            int y = StdRandom.uniform(100);
            StdOut.println("(" + x + ", " + y + ")");
            st.insert(x, y);
        }
        StdOut.println("Done preprocessing " + N + " points");

        // do some range searches
        for (int i = 0; i < M; i++) {
            int xmin = StdRandom.uniform(100);
            int ymin = StdRandom.uniform(100);;
            int xmax = xmin + StdRandom.uniform(10);
            int ymax = ymin + StdRandom.uniform(20);
            Interval<Integer> intX = new Interval<Integer>(xmin, xmax);
            Interval<Integer> intY = new Interval<Integer>(ymin, ymax);
            Interval2D<Integer> rect = new Interval2D<Integer>(intX, intY);
            StdOut.println(rect);
            st.query2D(rect);
        }
    }

}


Copyright © 2000–2017, Robert Sedgewick and Kevin Wayne.
Last updated: Fri Oct 20 12:50:46 EDT 2017.