Below is the syntax-highlighted version of RangeTree.java
from §9.2 Geometric Search.
/****************************************************************************** * Compilation: javac RangeTree.java * Execution: java RangeTree * * Range tree implemented using binary search trees. * * Assumes no two points have the same x or y coordinate!!!! * * Could be made more efficient by assuming points are given * all at once. Then could sort by x and by y to ensure tree * are perfectly balanced. * ******************************************************************************/ public class RangeTree<Key extends Comparable<Key>> { private Node root; // root of the primary BST // BST helper node data type private class Node { Key x, y; // x- and y- coordinates Node left, right; // left and right subtrees RangeSearch<Key, Key> bst; // secondary BST Node(Key x, Key y) { this.x = x; this.y = y; this.bst = new RangeSearch<Key, Key>(); this.bst.put(y, x); } } /*********************************************************************** * Insert (x, y) into range tree. * Sort by x in primary BST; sort by y in secondary BST. ***************************************************************************/ public void insert(Key x, Key y) { root = insert(root, x, y); } private Node insert(Node h, Key x, Key y) { if (h == null) return new Node(x, y); h.bst.put(y, x); if (less(x, h.x)) h.left = insert(h.left, x, y); else h.right = insert(h.right, x, y); return h; } /*************************************************************************** * Range searching. 
***************************************************************************/ // print all points in range public void query2D(Interval2D<Key> rect) { Interval<Key> intervalX = rect.intervalX; // find splitting node h where h.x is in the x-interval Node h = root; while (h != null && !intervalX.contains(h.x)) { if (less(intervalX.max(), h.x)) h = h.left; else if (less(h.x, intervalX.min())) h = h.right; } if (h == null) return; if (rect.contains(h.x, h.y)) StdOut.println("A: " + h.x + ", " + h.y); queryL(h.left, rect); queryR(h.right, rect); } // find all keys >= xmin in subtree rooted at h private void queryL(Node h, Interval2D<Key> rect) { if (h == null) return; if (rect.contains(h.x, h.y)) StdOut.println("B: " + h.x + ", " + h.y); if (!less(h.x, rect.intervalX.min())) { enumerate(h.right, rect); queryL(h.left, rect); } else { queryL(h.right, rect); } } // find all keys <= xmax in subtree rooted at h private void queryR(Node h, Interval2D<Key> rect) { if (h == null) return; if (rect.contains(h.x, h.y)) StdOut.println("C: " + h.x + ", " + h.y); if (!less(rect.intervalX.max(), h.x)) { enumerate(h.left, rect); queryR(h.right, rect); } else { queryR(h.left, rect); } } // precondition: subtree rooted at h has keys between xmin and xmax private void enumerate(Node h, Interval2D<Key> rect) { if (h == null) return; StdOut.println("integrity: " + h.bst.check()); Iterable<Key> list = h.bst.range(rect.intervalY); for (Key y : list) { Key x = h.bst.get(y); StdOut.println("D: " + x + ", " + y); } StdOut.println("-"); } /*************************************************************************** * Helper comparison functions. ***************************************************************************/ private boolean less(Key k1, Key k2) { return k1.compareTo(k2) < 0; } /*************************************************************************** * Test client. 
***************************************************************************/ public static void main(String[] args) { int M = Integer.parseInt(args[0]); // queries int N = Integer.parseInt(args[1]); // points RangeTree<Integer> st = new RangeTree<Integer>(); // insert N random points in the unit square for (int i = 0; i < N; i++) { int x = StdRandom.uniform(100); int y = StdRandom.uniform(100); StdOut.println("(" + x + ", " + y + ")"); st.insert(x, y); } StdOut.println("Done preprocessing " + N + " points"); // do some range searches for (int i = 0; i < M; i++) { int xmin = StdRandom.uniform(100); int ymin = StdRandom.uniform(100);; int xmax = xmin + StdRandom.uniform(10); int ymax = ymin + StdRandom.uniform(20); Interval<Integer> intX = new Interval<Integer>(xmin, xmax); Interval<Integer> intY = new Interval<Integer>(ymin, ymax); Interval2D<Integer> rect = new Interval2D<Integer>(intX, intY); StdOut.println(rect); st.query2D(rect); } } }