TAMING THE CURSE OF DIMENSIONALITY: DISCRETE INTEGRATION BY HASHING AND OPTIMIZATION. Stefano Ermon*, Carla P. Gomes*, Ashish Sabharwal+, and Bart Selman*. *Cornell University; +IBM Watson Research Center. ICML
High-dimensional integration. High-dimensional integrals arise throughout statistics, ML, and physics: expectations and model averaging, marginalization, and the partition function (to rank models and for parameter learning). Curse of dimensionality: quadrature involves a weighted sum over an exponential number of items (e.g., units of volume). [Figure: hypercube of side L, containing L, L^2, L^3, L^4, …, L^n units of volume in 1, 2, 3, 4, …, n dimensions.]
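To make the exponential cost concrete, here is a minimal sketch, assuming a tensor-product midpoint rule on the unit hypercube [0,1]^n with L points per axis. The function name `midpoint_rule` and the test integrand are illustrative, not from the slides. It evaluates the integrand once per unit of volume, i.e. L^n times.

```python
import itertools

def midpoint_rule(f, n, L):
    """Tensor-product midpoint rule on [0,1]^n with L points per axis.

    The weighted sum runs over L**n cells, one evaluation per cell.
    Returns (estimate, number_of_evaluations).
    """
    h = 1.0 / L                                   # side length of each cell
    nodes = [(k + 0.5) * h for k in range(L)]     # cell midpoints along one axis
    cell_volume = h ** n
    total, evals = 0.0, 0
    for point in itertools.product(nodes, repeat=n):
        total += f(point) * cell_volume
        evals += 1
    return total, evals

# f(p) = sum of coordinates has exact integral n/2 over [0,1]^n.
for n in range(1, 5):
    est, evals = midpoint_rule(lambda p: sum(p), n, L=10)
    print(n, evals, round(est, 6))
```

With L = 10, the evaluation count is 10^n, so n = 100 would already need 10^100 evaluations.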
Discrete integration. We are given a set of 2^n items with non-negative weights w. Goal: compute the total weight W = Σ_{x ∈ {0,1}^n} w(x). The weight function is specified compactly: in factored form (Bayes net, factor graph, CNF, …), or potentially by a Turing machine. Example 1: n = 2 dimensions, a sum over 4 items. Example 2: n = 100 dimensions, a sum over 2^100 ≈ 10^30 items (intractable). [Figure: the 2^n items; size visually represents weight.]
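As a baseline, exact discrete integration by enumeration is easy to write but only feasible for small n. In this sketch, `chain_weight` is a hypothetical factored weight function (a product of pairwise factors along a chain), standing in for a Bayes net or factor graph.

```python
import itertools
from typing import Callable, Tuple

Assignment = Tuple[int, ...]

def discrete_integral(w: Callable[[Assignment], float], n: int) -> float:
    """Exact W = sum of w(x) over x in {0,1}^n, by enumerating all 2**n items."""
    return sum(w(x) for x in itertools.product((0, 1), repeat=n))

def chain_weight(x: Assignment, coupling: float = 2.0) -> float:
    """Hypothetical factored weight: one factor per adjacent pair,
    equal to `coupling` if the two bits agree and 1.0 otherwise."""
    weight = 1.0
    for a, b in zip(x, x[1:]):
        weight *= coupling if a == b else 1.0
    return weight

# Example 1: n = 2 dimensions, sum over 4 items.
print(discrete_integral(chain_weight, 2))   # 00, 01, 10, 11 -> 2 + 1 + 1 + 2 = 6.0
```

The loop touches every one of the 2^n items, which is exactly what becomes intractable at n = 100.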
Hardness. With 0/1 weights: is there at least one “1”? (SAT, NP-complete.) How many “1”s? (#SAT, #P-complete, much harder.) With general weights: find the heaviest item (combinatorial optimization) vs. sum the weights (discrete integration). This work: approximate discrete integration via optimization. Combinatorial optimization (MIP, Max-SAT, CP) is also often fast in practice, thanks to relaxations/bounds and pruning. [Figure: complexity classes from easy to hard: P, NP, PH, P^#P, PSPACE, EXP.]
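Plain enumeration cannot exhibit the complexity-theoretic gap, but it does show what each query asks for: a SAT-style search may stop at the first witness, while a #SAT-style count must account for every item. The CNF below is a hypothetical example.

```python
import itertools

def items(n):
    """All 2**n assignments x in {0,1}^n, in lexicographic order."""
    return itertools.product((0, 1), repeat=n)

def exists_one(w, n):
    """SAT-style query for 0/1 weights: is there at least one x with w(x) = 1?
    Returns (answer, items_inspected); enumeration stops at the first witness."""
    inspected = 0
    for x in items(n):
        inspected += 1
        if w(x) == 1:
            return True, inspected
    return False, inspected

def count_ones(w, n):
    """#SAT-style query for 0/1 weights: how many x have w(x) = 1?
    Returns (count, items_inspected); every one of the 2**n items is accounted for."""
    count, inspected = 0, 0
    for x in items(n):
        inspected += 1
        count += w(x)
    return count, inspected

def cnf(x):
    """Hypothetical CNF over n >= 3 variables: (x0 OR x1) AND (NOT x0 OR x2)."""
    return int(bool(x[0] or x[1]) and bool(not x[0] or x[2]))

print(exists_one(cnf, 10))   # (True, 257)
print(count_ones(cnf, 10))   # (512, 1024)
```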
Previous approaches: sampling. Idea: randomly select a region, count within this region, and scale up appropriately. Advantage: quite fast. Drawbacks: robustness (it can easily under- or over-estimate) and scalability in sparse spaces (e.g., if only K of the N items have non-zero weight, a region must contain on the order of N/K items to “hit” even one). These can be partially mitigated using importance sampling.
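A minimal sketch of the naive version, assuming uniform sampling of individual items (no importance sampling): average the sampled weights and scale up by 2^n. The function name and the two test weight functions are illustrative.

```python
import random

def uniform_sampling_estimate(w, n, num_samples, seed=None):
    """Estimate W = sum of w(x) over {0,1}^n as 2**n times the mean weight of
    uniformly sampled items. Unbiased, but high-variance when the weight is sparse."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        x = tuple(rng.randint(0, 1) for _ in range(n))
        total += w(x)
    return (2 ** n) * total / num_samples    # scale up the sample mean

n = 20
dense = lambda x: 1.0                        # every item has weight 1, so W = 2**20
sparse = lambda x: 1.0 if all(x) else 0.0    # a single item has weight 1, so W = 1
print(uniform_sampling_estimate(dense, n, 1000, seed=0))    # 1048576.0
print(uniform_sampling_estimate(sparse, n, 1000, seed=0))
```

In the sparse case each sample hits the single non-zero item with probability 2^-20, so with 1000 samples the estimate is almost always 0, and a single lucky hit jumps it to about 1049: an under- or over-estimate either way.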
Previous approaches: variational methods. Idea: for exponential families, use convexity to obtain a variational formulation (an optimization problem), then solve it approximately (using message-passing techniques). Advantage: quite fast. Drawbacks: the objective function is defined indirectly; the domain of optimization cannot be represented compactly and must be approximated (BP, MF); typically there are no guarantees.
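As an illustration of the variational idea, here is a sketch of naive mean field (MF) on a small Ising model, assuming spins x_i in {-1, +1} and p(x) ∝ exp(Σ_{i<j} J_ij x_i x_j + Σ_i h_i x_i). The model, coupling ranges, and function names are hypothetical. The MF value is always a lower bound on log Z, but nothing controls how loose it is.

```python
import itertools
import math
import random

def energy(J, h, x):
    """sum_{i<j} J[i][j] x_i x_j + sum_i h[i] x_i, with J stored upper-triangular."""
    n = len(h)
    e = sum(h[i] * x[i] for i in range(n))
    e += sum(J[i][j] * x[i] * x[j] for i in range(n) for j in range(i + 1, n))
    return e

def exact_log_partition(J, h):
    """log Z by enumerating all 2**n spin configurations (small n only)."""
    es = [energy(J, h, x) for x in itertools.product((-1, 1), repeat=len(h))]
    top = max(es)
    return top + math.log(sum(math.exp(e - top) for e in es))

def mean_field_log_partition(J, h, sweeps=200):
    """Naive MF: q(x) = prod_i q_i(x_i) with magnetizations m_i = E_q[x_i].
    Coordinate ascent m_i <- tanh(h_i + sum_{j != i} J_ij m_j); returns
    E_q[energy] + H(q), which lower-bounds log Z for any m."""
    n = len(h)
    m = [0.0] * n
    for _ in range(sweeps):
        for i in range(n):
            field = h[i] + sum(J[min(i, j)][max(i, j)] * m[j] for j in range(n) if j != i)
            m[i] = math.tanh(field)
    entropy = 0.0
    for mi in m:
        for p in ((1 + mi) / 2, (1 - mi) / 2):
            if p > 0:
                entropy -= p * math.log(p)
    # energy() has no squared terms, so E_q[energy] equals energy evaluated at m.
    return energy(J, h, m) + entropy

rng = random.Random(0)
n = 8
J = [[rng.uniform(-1, 1) if j > i else 0.0 for j in range(n)] for i in range(n)]
h = [rng.uniform(-0.5, 0.5) for _ in range(n)]
print(exact_log_partition(J, h), mean_field_log_partition(J, h))
```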
A new approach: WISH. Suppose the items are sorted by weight, and plot, for each weight b, how many items have weight at least b (a CDF-style plot). The area under the curve equals the total weight we want to compute. How to estimate it? Divide it into slices and sum up. Geometrically divide the y axis (# items) at 1, 2, 4, …, 2^i, …, giving geometrically increasing bin sizes; the corresponding endpoint b_i is the 2^i-th largest weight (a quantile). Given the endpoints b_i, the area of each slice can be bounded within a factor of 2, so we have a 2-approximation. This also works if we only have approximations M_i of the b_i. How to estimate the b_i? [Figure: CDF-style plot of # items with weight at least b against weight b, with example endpoints b_0 = 100, b_1 = 70, b_2 = 9, b_3 = b_4 = 2.]
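The slicing argument can be checked directly. Slice i (for i = 0, …, n−1) holds the 2^i items ranked 2^i+1 through 2^(i+1), each with weight in [b_{i+1}, b_i], and the heaviest item contributes b_0. This sketch uses hypothetical weights chosen to reproduce the example endpoints in the figure; the function names are illustrative.

```python
def quantile_endpoints(weights):
    """b_0, ..., b_n for 2**n non-negative weights: b_i is the 2**i-th largest weight."""
    N = len(weights)
    n = N.bit_length() - 1
    if N != 2 ** n:
        raise ValueError("expects exactly 2**n items")
    ranked = sorted(weights, reverse=True)
    return [ranked[2 ** i - 1] for i in range(n + 1)]

def slice_bounds(b):
    """Lower/upper bounds on the area under the CDF-style plot:
    lower = b_0 + sum_i 2**i * b_{i+1},  upper = b_0 + sum_i 2**i * b_i."""
    n = len(b) - 1
    lower = b[0] + sum(2 ** i * b[i + 1] for i in range(n))
    upper = b[0] + sum(2 ** i * b[i] for i in range(n))
    return lower, upper

# Hypothetical 16 weights matching the example endpoints in the figure.
weights = [100, 70, 50, 9, 9, 5, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2]
b = quantile_endpoints(weights)
lower, upper = slice_bounds(b)
print(b, lower, sum(weights), upper)   # [100, 70, 9, 2, 2] 212 264 392
```

The lower bound has the same form as the sum WISH returns, and upper ≤ 2·lower always holds, which is the source of the factor 2.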
Estimating the endpoints (quantiles) b_i. Intuition: hash the 2^n items into 2^i buckets, then look at a single bucket and find the heaviest weight w_i in it. Example: for i = 2, hashing 16 items into 2^2 = 4 buckets gives w_i = 9 in the inspected bucket. Repeat several times. With high probability: if w_i is often found to be larger than w*, then there are at least 2^i items with weight larger than w*.
Hashing and optimization. Hash into 2^i buckets, then look at a single bucket. With probability > 0.5, both of the following hold: there is nothing from the small set of the 2^(i-2) = 2^i/4 heaviest items (it vanishes), and there is something from the larger set of the 2^(i+2) = 4·2^i heaviest items (it survives). So if we take the max over the bucket, it is likely to lie in the range [b_{i+2}, b_{i-2}]. Remember that the items are sorted, so the max picks the “rightmost” item (in order of increasing weight). [Figure: sorted items in geometrically increasing bins, with b_0, b_{i-2}, b_i, and b_{i+2} marked.]
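The "vanish and survive" event can be estimated by simulation. This sketch assumes an idealized, fully random hash (each item lands in a uniformly random bucket, independently) rather than the parity hash used by WISH, and assumes the universe holds at least 2^(i+2) items. The function name is illustrative.

```python
import random

def vanish_and_survive_rate(i, trials, seed=0):
    """Fraction of trials in which the inspected bucket (bucket 0 of 2**i) contains
    none of the 2**(i-2) heaviest items but at least one of the 2**(i+2) heaviest.
    Idealized hash: each item goes to a uniformly random bucket, independently."""
    if i < 2:
        raise ValueError("need i >= 2 so the small set has at least one item")
    rng = random.Random(seed)
    small, large = 2 ** (i - 2), 2 ** (i + 2)
    hits = 0
    for _ in range(trials):
        # Items are ranked 0 (heaviest), 1, 2, ...; only the top `large` matter here.
        lands_in_bucket = [rng.randrange(2 ** i) == 0 for _ in range(large)]
        vanishes = not any(lands_in_bucket[:small])
        survives = any(lands_in_bucket)
        if vanishes and survives:
            hits += 1
    return hits / trials

print(vanish_and_survive_rate(i=4, trials=2000))
```

Under full independence the exact probability for i = 4 is (15/16)^4 · (1 − (15/16)^60) ≈ 0.76. Pairwise independence alone already gives more than 1/2: the expected number of small-set items in the bucket is 1/4 (union bound), and the large-set count X has mean 4 and variance below 4, so by Chebyshev P(X = 0) < 1/4.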
Universal hashing. Represent each item as an n-bit vector x. Randomly generate A in {0,1}^(i×n) and b in {0,1}^i. Then A x + b (mod 2) is uniform and pairwise independent. The bucket content is implicitly defined by the solutions of A x = b (mod 2) (parity constraints). The value max w(x) subject to A x = b (mod 2) lies “frequently” in the range [b_{i+2}, b_{i-2}]. Repeat several times: the median is in the desired range with high probability.
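A small sketch of the parity hash h(x) = A x + b (mod 2), with an exhaustive check of pairwise independence. Since −b = b mod 2, the bucket h(x) = 0 is exactly the solution set of A x = b (mod 2). Function names and the example matrices are illustrative.

```python
import itertools
from collections import Counter

def parity_hash(A, b, x):
    """h(x) = A x + b (mod 2); A is a list of i rows of n bits, b has i bits, x has n bits."""
    return tuple((sum(a * xj for a, xj in zip(row, x)) + bk) % 2 for row, bk in zip(A, b))

def bucket(A, b, n):
    """The bucket WISH inspects: all x in {0,1}^n with h(x) = 0, i.e. A x = b (mod 2)."""
    zero = (0,) * len(b)
    return [x for x in itertools.product((0, 1), repeat=n) if parity_hash(A, b, x) == zero]

def pair_counts(x, y, i, n):
    """Tally (h(x), h(y)) over every A in {0,1}^(i x n) and every b in {0,1}^i."""
    counts = Counter()
    for bits in itertools.product((0, 1), repeat=i * n):
        A = [bits[r * n:(r + 1) * n] for r in range(i)]
        for b in itertools.product((0, 1), repeat=i):
            counts[(parity_hash(A, b, x), parity_hash(A, b, y))] += 1
    return counts

print(bucket([(1, 1, 0), (0, 1, 1)], (1, 0), n=3))   # [(0, 1, 1), (1, 0, 0)]
counts = pair_counts((1, 0, 1), (0, 1, 1), i=2, n=3)
print(len(counts), set(counts.values()))             # 16 {16}
```

For any two distinct items, all 4^i combinations of hash values occur equally often over the random choice of (A, b), which is the pairwise independence the analysis relies on.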
WISH: integration by hashing and optimization.
WISH (Weighted-Integrals-And-Sums-By-Hashing):
  T = log(n/δ)
  For i = 0, …, n:
    For t = 1, …, T:
      Sample uniformly A in {0,1}^(i×n), b in {0,1}^i
      w_i^t = max w(x) subject to A x = b (mod 2)
    M_i = Median(w_i^1, …, w_i^T)
  Return M_0 + Σ_{i=0}^{n-1} M_{i+1} 2^i
The outer loop runs over the n+1 endpoints b_i of the n slices: hash into 2^i buckets. The inner loop finds the heaviest item in the bucket, repeated log(n) times; M_i estimates the 2^i-th largest weight b_i. The returned sum adds up the estimated area of each vertical slice of the CDF-style plot. The algorithm requires only O(n log n) optimizations for a sum over 2^n items.
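A minimal runnable sketch of the loop above, assuming a brute-force "optimizer" that enumerates {0,1}^n (so only small n), a natural logarithm in T = log(n/δ), and that an empty bucket (inconsistent parity constraints) contributes 0. In practice the max would be delegated to a combinatorial optimization solver. The function name `wish` and the test weight are illustrative.

```python
import itertools
import math
import random
import statistics

def wish(w, n, delta=0.01, seed=0):
    """Sketch of WISH: M_i = median over T trials of max w(x) s.t. A x = b (mod 2),
    then return M_0 + sum_i M_{i+1} * 2**i. Brute-force optimizer, small n only."""
    rng = random.Random(seed)
    T = max(1, math.ceil(math.log(n / delta)))      # T = log(n/δ); natural log assumed
    items = list(itertools.product((0, 1), repeat=n))
    M = []
    for i in range(n + 1):                          # i random parity constraints
        maxima = []
        for _ in range(T):
            A = [[rng.randint(0, 1) for _ in range(n)] for _ in range(i)]
            b = [rng.randint(0, 1) for _ in range(i)]
            bucket = (x for x in items
                      if all(sum(a * xj for a, xj in zip(row, x)) % 2 == bk
                             for row, bk in zip(A, b)))
            # Assumption: an empty bucket contributes weight 0.
            maxima.append(max((w(x) for x in bucket), default=0))
        M.append(statistics.median(maxima))         # M_i estimates b_i
    return M[0] + sum(M[i + 1] * 2 ** i for i in range(n))

n = 8
weight = lambda x: 2 ** sum(x)      # hypothetical test weight; exact total is 3**n
print(wish(weight, n), 3 ** n)
```

Each (i, t) optimization is independent of the others, so the two loops parallelize trivially.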
Visual working of the algorithm. [Figure: the function to be integrated is constrained with 1, 2, 3, … random parity constraints (n times); each optimization is repeated log(n) times and the median is taken, giving M_1, M_2, M_3, …; with no constraints, the mode gives M_0; the estimate is M_0 + M_1×1 + M_2×2 + M_3×4 + …]
Accuracy guarantees. Theorem 1: with probability at least 1 − δ (e.g., 99.9%), WISH computes a 16-approximation of a sum over 2^n items (a discrete integral) by solving Θ(n log n) optimization instances. Example: the partition function, by solving Θ(n log n) MAP queries. Theorem 2: the approximation factor can be improved to (1+ε) by adding extra variables. Example: a factor-2 approximation with 4n variables. Byproduct: with high probability we also obtain an 8-approximation of the tail distribution (CDF).
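The 4n-variable example is consistent with the following construction, sketched here as an assumption rather than as the exact statement of Theorem 2: run WISH on k independent copies of the model (kn variables), whose total weight is W^k, and take the k-th root of the estimate. Here an α-approximation is read as an estimate within a factor α of the true value in either direction.

```latex
w^{(k)}(x^{1},\dots,x^{k}) = \prod_{j=1}^{k} w(x^{j}),
\qquad
W^{(k)} = \sum_{x^{1},\dots,x^{k} \in \{0,1\}^{n}} \prod_{j=1}^{k} w(x^{j})
        = \Big(\sum_{x \in \{0,1\}^{n}} w(x)\Big)^{k} = W^{k},

\frac{W^{k}}{16} \le \widehat{W}^{(k)} \le 16\,W^{k}
\;\Longrightarrow\;
\frac{W}{16^{1/k}} \le \big(\widehat{W}^{(k)}\big)^{1/k} \le 16^{1/k}\,W,
\qquad
k = 4:\ 16^{1/4} = 2 .
```

In general, a factor of (1+ε) would need k ≥ log 16 / log(1+ε) copies under this construction.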
Key features. Strong accuracy guarantees. Can plug in any combinatorial optimization tool. Bounds on the optimization translate to bounds on the sum: stopping early gives a lower bound (anytime), and (LP, SDP) relaxations give upper bounds. Extra constraints can make the optimization harder or easier. Massively parallel (independent optimizations). Remark: faster than enumeration only when combinatorial optimization is efficient (faster than brute force).
Experimental results. Approximate the partition function of undirected graphical models (the normalization constant needed to evaluate probabilities and rank models) by solving MAP queries (find the most likely state). MAP inference is run on the graphical model augmented with random parity constraints, using the Toulbar2 branch-and-bound solver, augmented with Gauss-Jordan filtering to efficiently handle the parity constraints (linear equations over the field GF(2)). Run in parallel using > 600 cores. [Figure: original graphical model plus parity check nodes enforcing A x = b (mod 2).]
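The parity constraints are linear equations over GF(2), so Gauss-Jordan elimination decides whether a bucket is empty and yields a solution. This sketch shows only the elimination step; it is not Toulbar2's implementation or its Gauss-Jordan filtering propagator, and the function names are hypothetical.

```python
def gauss_jordan_gf2(A, b):
    """Reduce A x = b (mod 2) to reduced row echelon form.
    Returns (rank, consistent, rows, pivots): rows are augmented rows [a_1..a_n | b],
    and pivots[r] is the pivot column of row r."""
    rows = [list(row) + [bk] for row, bk in zip(A, b)]
    n = len(A[0]) if A else 0
    pivots, r = [], 0
    for col in range(n):
        # Find a row at or below r with a 1 in this column.
        pivot = next((k for k in range(r, len(rows)) if rows[k][col] == 1), None)
        if pivot is None:
            continue                                   # free column
        rows[r], rows[pivot] = rows[pivot], rows[r]
        for k in range(len(rows)):                     # clear the column above and below
            if k != r and rows[k][col] == 1:
                rows[k] = [u ^ v for u, v in zip(rows[k], rows[r])]   # XOR = addition mod 2
        pivots.append(col)
        r += 1
        if r == len(rows):
            break
    # A row 0 = 1 (all-zero coefficients, right-hand side 1) means no solution.
    consistent = all(any(row[:n]) or row[n] == 0 for row in rows)
    return r, consistent, rows, pivots

def particular_solution(A, b):
    """One solution of A x = b (mod 2) with free variables set to 0, or None.
    Assumes A has at least one row."""
    n = len(A[0])
    rank, consistent, rows, pivots = gauss_jordan_gf2(A, b)
    if not consistent:
        return None
    x = [0] * n
    for r, col in enumerate(pivots):
        x[col] = rows[r][n]
    return tuple(x)

print(gauss_jordan_gf2([[1, 1, 0], [0, 1, 1]], [1, 0])[:2])   # (2, True)
print(particular_solution([[1, 1, 0], [0, 1, 1]], [1, 0]))     # (1, 0, 0)
print(particular_solution([[1, 1], [1, 1]], [0, 1]))           # None
```

A consistent system of rank r over n variables has exactly 2^(n−r) solutions, so the size of the bucket is known before any optimization is run.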
Sudoku. How many ways are there to fill a valid Sudoku square? Sum over 9^81 (≈ 2×10^77) possible squares (items), with w(x) = 1 if x is a valid square and w(x) = 0 otherwise. Accurate solution within seconds: 1.634×10^21 vs. 6.671× …?
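The 9×9 sum is far beyond enumeration, but the same 0/1 weight function can be summed exactly for the 4×4 analogue (4^16 ≈ 4.3×10^9 fillings) by backtracking, which skips every filling that a partial conflict already rules out. This is exact counting on a toy instance, not WISH; the function name is illustrative.

```python
def count_sudoku_grids(box):
    """Count valid (box*box) x (box*box) Sudoku grids, i.e. sum w(x) over all
    fillings with w(x) = 1 for a valid grid, by backtracking cell by cell."""
    size = box * box
    grid = [[0] * size for _ in range(size)]   # 0 marks an empty cell

    def ok(r, c, v):
        """True if value v conflicts with nothing in row r, column c, or its box."""
        if any(grid[r][k] == v for k in range(size)):
            return False
        if any(grid[k][c] == v for k in range(size)):
            return False
        br, bc = box * (r // box), box * (c // box)
        return all(grid[br + dr][bc + dc] != v for dr in range(box) for dc in range(box))

    def fill(cell):
        if cell == size * size:
            return 1                            # a complete valid grid has weight 1
        r, c = divmod(cell, size)
        total = 0
        for v in range(1, size + 1):
            if ok(r, c, v):
                grid[r][c] = v
                total += fill(cell + 1)
                grid[r][c] = 0
        return total

    return fill(0)

print(count_sudoku_grids(2))   # 288
```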
Random cliques Ising models. [Figure: partition function estimates vs. strength of the interactions, with the MAP query value for reference; the very small error band is the 16-approximation range.] Other methods fall way outside the error band.
Model ranking: MNIST. Use the partition function estimate to rank models by data likelihood. WISH ranks them correctly; mean-field and BP do not. [Figure: visually, a better model for handwritten digits.]
Conclusions. Discrete integration is reduced to a small number of optimization instances. Strong (probabilistic) accuracy guarantees by universal hashing. Can leverage fast combinatorial optimization packages. Works well in practice. Future work: extension to continuous integrals; further approximations in the optimization [UAI-13]; coding theory (parity check codes, max-likelihood decoding); LP relaxations; sampling from high-dimensional probability distributions?