$ cat sxp.rkt
#lang racket
(require racket/set)
(define (smallest-at-least-po2 a b)
  ;; Return 2^n, where n is the number of bits needed to represent the
  ;; larger of a and b, counting at least one bit.  This is the smallest
  ;; power of 2 strictly greater than both values, except that it is never
  ;; less than 2 (so a = b = 0 gives 2).  Sums of n-bit values are reduced
  ;; modulo this number.
  ;; NOTE(review): assumes a and b are non-negative -- confirm with callers.
  (define (bit-size i)
    ;; Return the smallest number of bits required to represent i, with a
    ;; minimum of 1.  integer-length counts bits exactly.  A floating-point
    ;; (ceiling (/ (log i) (log 2))) estimate can be off by one for large
    ;; i whose flonum conversion rounds (e.g. 2^60 +/- 1).  Flooring first
    ;; keeps non-integer inputs working: for them ceiling(log2 i) equals the
    ;; bit length of floor(i).
    (if (< i 2)
        1
        (integer-length (inexact->exact (floor i)))))
  (expt 2 (max (bit-size a) (bit-size b))))
; These functions treat numbers as n-bit values, where n is the
; smallest number of bits capable of representing both the sum and
; xor values. They also ignore overflow, so a pair's values may be
; larger than the sum value and still add up to it modulo 2^n (for
; example, with n = 4, 10 + 15 = 25, which is 9 modulo 16). Also,
; because xor and sum are commutative, each number pair is created
; with the smaller value first (on the left, in the car).
(define (sum-xor-pairs-n ab-sum ab-xor)
  ;; Return a set of number pairs such that, for each pair, the sum of the
  ;; two numbers equals ab-sum (ignoring overflow) and the xor of the two
  ;; numbers equals ab-xor.  Each pair is a cons cell (a . b) with a <= b.
  ;;
  ;; ab-sum and ab-xor must be non-negative integers; any other value
  ;; raises exn:fail:contract.
  ;;
  ;; The sum and xor are consumed one bit at a time, least-significant bit
  ;; first, until both are zero.  The bit width n is therefore the larger
  ;; of their bit lengths (1 when both are zero).  At each bit a branch is
  ;; either pruned or forked in two, so at most 2^n branches are explored.
  ;; That is roughly linear in the larger of the sum and xor values.
  ;;
  ;; Per bit, this implements the truth tables for sum and xor:
  ;;
  ;;   Ci a b   Co s x
  ;;    0 0 0    0 0 0
  ;;    0 0 1    0 1 1
  ;;    0 1 0    0 1 1
  ;;    0 1 1    1 0 0
  ;;    1 0 0    0 1 0
  ;;    1 0 1    1 0 1
  ;;    1 1 0    1 0 1
  ;;    1 1 1    1 1 0
  ;;
  ;;   Ci  the carry-in bit
  ;;   a   a bit from one of the pair values
  ;;   b   the corresponding bit from the other pair value
  ;;   Co  the carry-out of Ci + a + b
  ;;   s   the low bit of Ci + a + b
  ;;   x   a xor b
  ;;
  ;; Interesting things to note about this table:
  ;;   The s and x bits are the same when Ci = 0.
  ;;   The s and x bits are different when Ci = 1.
  ;;   xor ignores Ci (and Co).
  (define (add-msb a-msb b-msb ab-pairs)
    ;; Prepend a-msb to the first bit list and b-msb to the second bit list
    ;; of every #(a-bits b-bits) vector in ab-pairs; return the new set.
    (for/set ([v (in-set ab-pairs)])
      (vector (cons a-msb (vector-ref v 0))
              (cons b-msb (vector-ref v 1)))))
  (define (vector->pair ab-pairs)
    ;; Turn each #(a-bits b-bits) vector in ab-pairs (bit lists, msb first)
    ;; into a number pair with the smaller value in the car.
    (define (implode bit-list)
      ;; Return the number whose binary digits, msb first, are bit-list.
      (for/fold ([n 0]) ([bit (in-list bit-list)])
        (+ (* n 2) bit)))
    (for/set ([v (in-set ab-pairs)])
      (let ([a (implode (vector-ref v 0))]
            [b (implode (vector-ref v 1))])
        (cons (min a b) (max a b)))))
  (define (oops [emsg "some unfathomable error"])
    ;; Report an impossible bit/carry combination (unreachable for valid
    ;; input).  emsg is optional because every call site passes nothing;
    ;; the original arguments are attached to make the failure reproducible.
    (raise-arguments-error 'sum-xor-pairs-n emsg
                           "ab-sum" ab-sum
                           "ab-xor" ab-xor))
  (define (require-non-negative-integer v pos)
    ;; Negative inputs would otherwise yield -1 "bits" and either a bogus
    ;; empty result or a crash deep in the recursion.
    (unless (and (integer? v) (not (negative? v)))
      (raise-argument-error 'sum-xor-pairs-n
                            "(and/c integer? (not/c negative?))"
                            pos ab-sum ab-xor)))
  (require-non-negative-integer ab-sum 0)
  (require-non-negative-integer ab-xor 1)
  (vector->pair
   (let loop ([ab-sum ab-sum]
              [ab-xor ab-xor]
              [carry-in 0]
              ;; When both inputs are zero the width is one bit, and both
              ;; 0+0 and 1+1 (mod 2) qualify.  Otherwise begin with a single
              ;; empty partial solution.
              [ab-pairs (if (and (= ab-sum 0) (= ab-xor 0))
                            (set #((0) (0)) #((1) (1)))
                            (set #(() ())))])
     (if (and (= ab-sum 0) (= ab-xor 0))
         ;; Every bit has been consumed.  A pending carry-in would flow out
         ;; of the top bit; that overflow is deliberately ignored.
         ab-pairs
         ;; Plain let, not let*: sum-bit and xor-bit must come from the
         ;; current values before ab-sum and ab-xor are shifted right.
         (let ([sum-bit (remainder ab-sum 2)]
               [xor-bit (remainder ab-xor 2)]
               [ab-sum (quotient ab-sum 2)]
               [ab-xor (quotient ab-xor 2)])
           (cond
             [(= carry-in 0)
              (cond
                [(and (= sum-bit 0) (= xor-bit 0))
                 ;; a = b: 0+0 leaves no carry, 1+1 carries out.
                 (set-union
                  (loop ab-sum ab-xor 0 (add-msb 0 0 ab-pairs))
                  (loop ab-sum ab-xor 1 (add-msb 1 1 ab-pairs)))]
                [(and (= sum-bit 1) (= xor-bit 1))
                 ;; a /= b: 0+1 or 1+0, neither carries out.
                 (set-union
                  (loop ab-sum ab-xor 0 (add-msb 0 1 ab-pairs))
                  (loop ab-sum ab-xor 0 (add-msb 1 0 ab-pairs)))]
                [(not (= sum-bit xor-bit))
                 ;; With no carry-in, s must equal x, so this branch has
                 ;; no solutions.
                 (set)]
                [else (oops)])]
             [(= carry-in 1)
              (cond
                [(and (= sum-bit 1) (= xor-bit 0))
                 ;; a = b: 1+0+0 leaves no carry, 1+1+1 carries out.
                 (set-union
                  (loop ab-sum ab-xor 0 (add-msb 0 0 ab-pairs))
                  (loop ab-sum ab-xor 1 (add-msb 1 1 ab-pairs)))]
                [(and (= sum-bit 0) (= xor-bit 1))
                 ;; a /= b: 1+0+1 or 1+1+0, both carry out.
                 (set-union
                  (loop ab-sum ab-xor 1 (add-msb 0 1 ab-pairs))
                  (loop ab-sum ab-xor 1 (add-msb 1 0 ab-pairs)))]
                [(= sum-bit xor-bit)
                 ;; With a carry-in, s must differ from x, so this branch
                 ;; has no solutions.
                 (set)]
                [else (oops)])]
             [else (oops)]))))))
(define (sum-xor-pairs-nsq ab-sum ab-xor)
  ;; Brute-force reference for sum-xor-pairs-n.  Return the set of every
  ;; pair (lo . hi) with 0 <= lo <= hi < N such that (lo + hi) mod N equals
  ;; ab-sum and lo xor hi equals ab-xor, where N is
  ;; (smallest-at-least-po2 ab-sum ab-xor).
  ;; All N*(N+1)/2 candidate pairs are tried, so the work is quadratic in N.
  (define modulus (smallest-at-least-po2 ab-sum ab-xor))
  (for*/set ([lo (in-range modulus)]
             [hi (in-range lo modulus)]
             #:when (and (= (remainder (+ lo hi) modulus) ab-sum)
                         (= (bitwise-xor lo hi) ab-xor)))
    (cons lo hi)))
;; Demo: print the pairs whose sum is 9 and whose xor is 5.
(let ([demo-sum 9]
      [demo-xor 5])
  (sum-xor-pairs-n demo-sum demo-xor))
(require rackunit)
(define (check-sum-xor-pairs n)
  ;; Cross-check sum-xor-pairs-n for every ab-sum and ab-xor in 0..n.
  ;; Every returned pair must really have that sum (modulo N) and that
  ;; xor, and the whole set must equal the brute-force result of
  ;; sum-xor-pairs-nsq.  Failures are reported through rackunit checks;
  ;; the procedure itself returns #t.
  (define (check-sum-xor-list ab-pairs ab-sum ab-xor)
    ;; Check that each (a . b) in ab-pairs adds to ab-sum modulo N and
    ;; xors to ab-xor.  N is the same modulus sum-xor-pairs-nsq uses.
    (define N (smallest-at-least-po2 ab-sum ab-xor))
    (for ([p (in-set ab-pairs)])
      ;; check-equal? rather than check-eq?: eq? on numbers is only
      ;; guaranteed to work for fixnums.
      (check-equal? (remainder (+ (car p) (cdr p)) N) ab-sum)
      (check-equal? (bitwise-xor (car p) (cdr p)) ab-xor)))
  (do ((ab-sum 0 (+ 1 ab-sum))) ((> ab-sum n) #t)
    (do ((ab-xor 0 (+ 1 ab-xor))) ((> ab-xor n) #t)
      (let ((ab-pairs-n (sum-xor-pairs-n ab-sum ab-xor)))
        (check-sum-xor-list ab-pairs-n ab-sum ab-xor)
        ;; The comparison must be asserted.  A bare set=? result would
        ;; be discarded, so missing or extra pairs would go unreported.
        (check set=? ab-pairs-n (sum-xor-pairs-nsq ab-sum ab-xor))))))
;; Exhaustively cross-check both solvers for all sum and xor values 0..100.
(let ([limit 100])
  (check-sum-xor-pairs limit))
(define (time-it f n iters)
  ;; Call f on every (ab-sum, ab-xor) combination with both values running
  ;; 0, 1, ... up to n.  Repeat that sweep iters times and return the mean
  ;; CPU time per sweep in milliseconds, rounded to an exact integer.
  (define (sweep)
    (let outer ([s 0])
      (unless (> s n)
        (let inner ([x 0])
          (unless (> x n)
            (f s x)
            (inner (+ x 1))))
        (outer (+ s 1))))
    #t)
  (define (cpu-ms-of thunk)
    ;; time-apply yields the thunk's results followed by CPU, real, and GC
    ;; milliseconds; only the CPU figure is kept.
    (define-values (results cpu-ms real-ms gc-ms) (time-apply thunk '()))
    cpu-ms)
  (let repeat ([done 0] [total-ms 0])
    (if (= done iters)
        (inexact->exact (round (/ total-ms iters)))
        (repeat (+ done 1) (+ total-ms (cpu-ms-of sweep))))))
;; Benchmark: for sweep limits 10, 20, ..., 100, print the average CPU
;; milliseconds (over 3 runs) each solver needs to cover every sum and
;; xor value from 0 up to the limit.
(let ([iters 3])
  (for ([limit (in-range 10 101 10)])
    (printf
     "sum-xor max: ~a, sum-xor-pairs-n: ~a, sum-xor-pairs-nsq: ~a\n"
     limit
     (time-it sum-xor-pairs-n limit iters)
     (time-it sum-xor-pairs-nsq limit iters)))
  #t)
$ mzscheme sxp.rkt
(set '(2 . 7) '(3 . 6) '(10 . 15) '(11 . 14))
#t
sum-xor max: 10, sum-xor-pairs-n: 3, sum-xor-pairs-nsq: 1
sum-xor max: 20, sum-xor-pairs-n: 7, sum-xor-pairs-nsq: 3
sum-xor max: 30, sum-xor-pairs-n: 17, sum-xor-pairs-nsq: 9
sum-xor max: 40, sum-xor-pairs-n: 32, sum-xor-pairs-nsq: 39
sum-xor max: 50, sum-xor-pairs-n: 52, sum-xor-pairs-nsq: 80
sum-xor max: 60, sum-xor-pairs-n: 73, sum-xor-pairs-nsq: 128
sum-xor max: 70, sum-xor-pairs-n: 107, sum-xor-pairs-nsq: 296
sum-xor max: 80, sum-xor-pairs-n: 145, sum-xor-pairs-nsq: 547
sum-xor max: 90, sum-xor-pairs-n: 188, sum-xor-pairs-nsq: 827
sum-xor max: 100, sum-xor-pairs-n: 236, sum-xor-pairs-nsq: 1128
#t
$