$ cat sxp.rkt
#lang racket

(require racket/set)

(define (smallest-at-least-po2 a b)

  ; Return the smallest power of 2 that is strictly greater than both
  ; of the given values (and never less than 2).  Equivalently, return
  ; 2^n where n is the smallest number of bits capable of representing
  ; both values; the result is the modulus for n-bit arithmetic.

  (define (bit-size i)

    ; Return the smallest number of bits required to represent i
    ; (never less than 1).  integer-length is exact for arbitrarily
    ; large integers, unlike a floating-point log2 computation.

    (if (< i 2)
        1
        (integer-length i)))

  (expt 2 (max (bit-size a) (bit-size b))))


; These functions treat numbers as n-bit values, where n is the
; smallest number of bits capable of representing the sum and xor
; values.  They also ignore overflow, which means it's possible
; to have pair values that are larger than the sum value but
; still add to the sum value.  Also, because xor and sum are
; commutative, the number pairs are created with the smaller
; value first (on the left, the car).


(define (sum-xor-pairs-n ab-sum ab-xor)

  ; Return a set of number pairs such that, for each pair, the
  ; sum of the two numbers equals ab-sum (ignoring overflow) and
  ; the xor of the two numbers equals ab-xor.

  ; This function does work proportional to the power of the
  ; larger of the smallest number of bits needed to represent the
  ; sum and xor values, which is roughly linear in the larger of
  ; the sum and xor values.

  ; This function essentially implements the truth tables for sum
  ; and xor:
  ;
  ;   Ci a b  Co s x
  ;    0 0 0   0 0 0
  ;    0 0 1   0 1 1
  ;    0 1 0   0 1 1
  ;    0 1 1   1 0 0
  ;    1 0 0   0 1 0
  ;    1 0 1   1 0 1
  ;    1 1 0   1 0 1
  ;    1 1 1   1 1 0
  ;
  ;   Ci  the carry-in bit
  ;   a   a bit from one of the pair values
  ;   b   the corresponding bit from the other pair value
  ;   Co  the carry-out from Ci + a + b
  ;   s   the low bit of Ci + a + b
  ;   x   a xor b
  ;
  ; Interesting things to note about this table:
  ;
  ;   The s and x bits are the same when Ci = 0.
  ;   The s and x bits are different when Ci = 1.
  ;   xor ignores Ci (and Co).

  (define (add-msb a-msb b-msb ab-pairs)

    ; Add the given most-significant bits to the given set of
    ; number pairs; return the new set.

    (for/set ([v ab-pairs])
      (vector (cons a-msb (vector-ref v 0))
              (cons b-msb (vector-ref v 1)))))

  (define (vector->pair ab-pairs)

    ; Return a set of number pairs, where each number corresponds
    ; to a bit list in the given pairs and each pair corresponds
    ; to a vector in the given set.  The smaller value appears
    ; first in the pair.

    (define (implode bit-list)

      ; Return the number equivalent to the given bit list (msb on
      ; the left).

      (let loop ((n 0) (bit-list bit-list))
        (if (null? bit-list)
            n
            (loop (+ (* n 2) (car bit-list)) (cdr bit-list)))))

    (for/set ((v ab-pairs))
      (let ((a (implode (vector-ref v 0)))
            (b (implode (vector-ref v 1))))
        (cons (min a b) (max a b)))))

  (define (oops sum-bit xor-bit carry-in)

    ; Report a bit combination the truth table doesn't cover.  This
    ; can only happen when a bit is something other than 0 or 1
    ; (for example, when both arguments are negative).

    (raise-arguments-error 'sum-xor-pairs-n
                           "some unfathomable error"
                           "sum-bit" sum-bit
                           "xor-bit" xor-bit
                           "carry-in" carry-in))

  ; Walk the sum and xor values from the least-significant bit up,
  ; extending every candidate pair with each (a, b) bit combination
  ; the truth table allows for the current sum bit, xor bit and
  ; carry-in.  When both values are zero to begin with, a single bit
  ; position is assumed, for which 0+0 and 1+1 (overflow ignored) are
  ; the solutions.

  (vector->pair
   (let loop ((ab-sum ab-sum)
              (ab-xor ab-xor)
              (carry-in 0)
              (ab-pairs (if (and (= ab-sum 0) (= ab-xor 0))
                            (set #((0) (0)) #((1) (1)))
                            (set #(() ())))))
     (if (and (= ab-sum 0) (= ab-xor 0))
         ab-pairs
         (let ((sum-bit (remainder ab-sum 2))
               (xor-bit (remainder ab-xor 2))
               (ab-sum (quotient ab-sum 2))
               (ab-xor (quotient ab-xor 2)))
           (cond
             ((= carry-in 0)
              (cond
                ((and (= sum-bit 0) (= xor-bit 0))
                 (set-union
                  (loop ab-sum ab-xor 0 (add-msb 0 0 ab-pairs))
                  (loop ab-sum ab-xor 1 (add-msb 1 1 ab-pairs))))
                ((and (= sum-bit 1) (= xor-bit 1))
                 (set-union
                  (loop ab-sum ab-xor 0 (add-msb 0 1 ab-pairs))
                  (loop ab-sum ab-xor 0 (add-msb 1 0 ab-pairs))))
                ((not (= sum-bit xor-bit))

                 ; If the carry-in's zero, the sum of the two
                 ; bits (ignoring carry-out) must equal the xor
                 ; of the two bits.  Because that's not the
                 ; case, there can be no solutions down this
                 ; branch.

                 (set))
                (else (oops sum-bit xor-bit carry-in))))
             ((= carry-in 1)
              (cond
                ((and (= sum-bit 1) (= xor-bit 0))
                 (set-union
                  (loop ab-sum ab-xor 0 (add-msb 0 0 ab-pairs))
                  (loop ab-sum ab-xor 1 (add-msb 1 1 ab-pairs))))
                ((and (= sum-bit 0) (= xor-bit 1))
                 (set-union
                  (loop ab-sum ab-xor 1 (add-msb 0 1 ab-pairs))
                  (loop ab-sum ab-xor 1 (add-msb 1 0 ab-pairs))))
                ((= sum-bit xor-bit)

                 ; If the carry-in's one, the sum of the two
                 ; bits (ignoring carry-out) cannot equal the
                 ; xor of the two bits.  Because here they are
                 ; equal, there can be no solutions down this
                 ; branch.

                 (set))
                (else (oops sum-bit xor-bit carry-in))))
             (else (oops sum-bit xor-bit carry-in))))))))


(define (sum-xor-pairs-nsq ab-sum ab-xor)

  ; Return a number-pair set such that, for each pair, the sum of
  ; the two numbers equals ab-sum (ignoring overflow) and the xor
  ; of the two numbers equals ab-xor.

  ; This function tries every pair (a . b) with 0 <= a <= b < N,
  ; where N is the power-of-2 modulus for the two values, so it
  ; does work proportional to N*N (a.k.a. n-squared).

  (let ((N (smallest-at-least-po2 ab-sum ab-xor)))
    (let outer-loop ((a 0) (ab-pairs (set)))
      (if (= a N)
          ab-pairs
          (let inner-loop ((b a) (ab-pairs ab-pairs))
            (if (= b N)
                (outer-loop (+ a 1) ab-pairs)
                (inner-loop
                 (+ b 1)
                 (if (and (= (remainder (+ a b) N) ab-sum)
                          (= (bitwise-xor a b) ab-xor))
                     (set-add ab-pairs (cons a b))
                     ab-pairs))))))))


(sum-xor-pairs-n 9 5)


(require rackunit)

(define (check-sum-xor-pairs n)

  ; For every sum and xor value from 0 to n inclusive, check that
  ; each pair produced by sum-xor-pairs-n really has the requested
  ; sum (modulo the power-of-2 modulus) and xor, and that the set
  ; of pairs agrees with the brute-force sum-xor-pairs-nsq.
  ; Return #t; failures are reported through rackunit.

  (define (check-sum-xor-list ab-pairs ab-sum ab-xor)

    ; Check every pair in the given set against the given sum and
    ; xor values.

    (define N (smallest-at-least-po2 ab-sum ab-xor))

    (for ([p (in-set ab-pairs)])
      (let ((a (car p))
            (b (cdr p)))
        (check-equal? (remainder (+ a b) N) ab-sum)
        (check-equal? (bitwise-xor a b) ab-xor))))

  (do ((ab-sum 0 (+ 1 ab-sum))) ((> ab-sum n) #t)
    (do ((ab-xor 0 (+ 1 ab-xor))) ((> ab-xor n) #t)
      (let ((ab-pairs-n (sum-xor-pairs-n ab-sum ab-xor)))
        (check-sum-xor-list ab-pairs-n ab-sum ab-xor)

        ; The comparison has to be wrapped in a check; a bare set=?
        ; computes the answer and then throws it away.

        (check-true
         (set=? ab-pairs-n (sum-xor-pairs-nsq ab-sum ab-xor))
         (format "pair sets differ for sum ~a, xor ~a" ab-sum ab-xor))))))

(check-sum-xor-pairs 100)


(define (time-it f n iters)

  ; Call f on every (sum, xor) combination with both values running
  ; from 0 to n inclusive, repeat that iters times, and return the
  ; average CPU time per repetition in milliseconds, rounded to an
  ; exact integer.

  (define (run-test)
    (do ((ab-sum 0 (+ 1 ab-sum))) ((> ab-sum n) #t)
      (do ((ab-xor 0 (+ 1 ab-xor))) ((> ab-xor n) #t)
        (f ab-sum ab-xor))))

  (let loop ((t 0) (i 0))
    (if (= i iters)
        (inexact->exact (round (/ t iters)))

        ; time-apply returns the result list followed by the cpu,
        ; real and gc times; only the cpu time is accumulated.

        (let-values (((results cpu-ms real-ms gc-ms)
                      (time-apply run-test '())))
          (loop (+ t cpu-ms) (+ i 1))))))

(let ((iters 3))
  (do ((i 10 (+ 10 i))) ((> i 100) #t)
    (printf "sum-xor max: ~a, sum-xor-pairs-n: ~a, sum-xor-pairs-nsq: ~a\n"
            i
            (time-it sum-xor-pairs-n i iters)
            (time-it sum-xor-pairs-nsq i iters))))

$ mzscheme sxp.rkt
(set '(2 . 7) '(3 . 6) '(10 . 15) '(11 . 14))
#t
sum-xor max: 10, sum-xor-pairs-n: 3, sum-xor-pairs-nsq: 1
sum-xor max: 20, sum-xor-pairs-n: 7, sum-xor-pairs-nsq: 3
sum-xor max: 30, sum-xor-pairs-n: 17, sum-xor-pairs-nsq: 9
sum-xor max: 40, sum-xor-pairs-n: 32, sum-xor-pairs-nsq: 39
sum-xor max: 50, sum-xor-pairs-n: 52, sum-xor-pairs-nsq: 80
sum-xor max: 60, sum-xor-pairs-n: 73, sum-xor-pairs-nsq: 128
sum-xor max: 70, sum-xor-pairs-n: 107, sum-xor-pairs-nsq: 296
sum-xor max: 80, sum-xor-pairs-n: 145, sum-xor-pairs-nsq: 547
sum-xor max: 90, sum-xor-pairs-n: 188, sum-xor-pairs-nsq: 827
sum-xor max: 100, sum-xor-pairs-n: 236, sum-xor-pairs-nsq: 1128
#t
$