/*
[ create a new paste ] login | about

Link: http://codepad.org/Zy7elORD    [ raw code | fork ]

C++, pasted on Jan 22:
*/
#include <bits/stdc++.h>

// Contest-style shorthand macros.
// sz() parenthesizes its argument so expressions such as sz(*p) expand to
// ((int)(*p).size()) rather than the ill-formed (int)*p.size().
#define sz(z) ((int)(z).size())
// fo(i, a, b): loop i over the half-open range [a, b).
#define fo(i,a,b) for (auto (i) = (a); (i) < (b); (i)++)
#define mp make_pair
#define pb push_back

using namespace std;

// Debug switch: as written this defines DEBUGz, not DEBUG, so the #ifdef below
// is false and D(...) expands to nothing. Define DEBUG to get the printf trace.
#define DEBUGz

#ifdef DEBUG
#define D(x...) printf(x)
#else
#define D(x...) 
#endif

typedef long long ll;
typedef pair<int,int> ii;

// Treap node. In-order traversal is non-decreasing in v (insert() sends
// val >= v to the right), and pr acts as a max-heap priority (fix() lifts a
// node while its pr exceeds its parent's). Each node caches aggregates over
// its own subtree; upd() recomputes them from the children.
struct ns {
	int sum;        // number of nodes in this subtree, this one included
	int v;          // stored value (BST key)
	int pr;         // random heap priority
	int sml;        // value of the leftmost node, i.e. the subtree minimum
	int big;        // value of the rightmost node, i.e. the subtree maximum
	ns *l, *r, *p;  // left child, right child, parent (NULL when absent)

	// Recompute sum/sml/big; both children must already be up to date.
	void upd () {
		sum = 1;
		sml = big = v;
		if (l != NULL) { sum += l->sum; sml = l->sml; }
		if (r != NULL) { sum += r->sum; big = r->big; }
	}
};

// Preallocated pool of treap nodes handed out by newS(); counterS is the index
// of the next unused slot. Nothing in this file returns a node to the pool,
// and newS() does not check counterS against the array bound.
ns noS[9111111];
int counterS;

// Segment-tree node covering positions [s, e] of stuff. h is the root of a
// treap that build() fills with every stuff[i] for i in [s, e].
struct nr {
	int s;   // first covered position (inclusive)
	int e;   // last covered position (inclusive)
	nr *l;   // left half [s, mid], NULL for a leaf
	nr *r;   // right half [mid+1, e], NULL for a leaf
	ns *h;   // treap of the values in this range
};

// Preallocated pool of segment-tree nodes handed out by newR(); counterR is
// the index of the next unused slot (newR() does not bounds-check it).
nr noR[8111111];
int counterR;

// Current value at each position, 1-indexed: main() reads positions 1..n
// into it and writes stuff[a] = b after each update().
int stuff[111111];	

// Takes the next slot from the noR pool and initializes it as a childless
// segment-tree node for [s, e] with no treap attached yet.
nr* newR (int s, int e) {
	nr* node = noR + counterR;
	++counterR;
	node->h = NULL;
	node->r = NULL;
	node->l = NULL;
	node->e = e;
	node->s = s;
	return node;
}

// Takes the next slot from the noS pool and initializes it as a detached,
// single-node treap holding val with a rand() priority.
ns* newS (int val) {
	ns* node = noS + counterS;
	++counterS;
	node->l = NULL;
	node->r = NULL;
	node->p = NULL;
	node->sum = 1;
	node->sml = val;
	node->big = val;
	node->v = val;
	node->pr = rand();
	return node;
}

// Rotates x, which must be the RIGHT child of its parent p, up one level:
// p becomes x's left child and x's old left subtree becomes p->r. In-order
// order is preserved. p is refreshed before x because p is now x's child and
// x->upd() reads p->sum/p->sml. The grandparent's subtree contains the same
// nodes as before, so its cached aggregates are unaffected.
void rotateR (ns* x) {
	ns* p = x->p;  // must be non-NULL: it is dereferenced unconditionally below
	ns* g = p->p;
	p->r = x->l;
	if (p->r) p->r->p = p;
	x->l = p;
	p->p = x;
	if (g) {
		if (g->l == p) g->l = x;
		else g->r = x;
	}
	x->p = g;
	// Bottom-up: child first, then the new subtree root.
	p->upd(); x->upd();
}

// Rotates x, which must be the LEFT child of its parent p, up one level:
// p becomes x's right child and x's old right subtree becomes p->l. In-order
// order is preserved. p is refreshed before x because p is now x's child and
// x->upd() reads p->sum/p->big. The grandparent's subtree contains the same
// nodes as before, so its cached aggregates are unaffected.
void rotateL (ns* x) {
	ns* p = x->p;  // must be non-NULL: it is dereferenced unconditionally below
	ns* g = p->p;
	p->l = x->r;
	if (p->l) p->l->p = p;
	p->p = x;
	x->r = p;
	if (g) {
		if (g->l == p) g->l = x;
		else g->r = x;
	}
	x->p = g;
	// Bottom-up: child first, then the new subtree root.
	p->upd(); x->upd();
}

// Returns 1 if x is its parent's left child and 0 if it is the right child.
// x->p must be non-NULL.
int dir (ns* x) {
	ns* par = x->p;
	return par->l == x ? 1 : 0;
}

// Lifts x one level above its parent, choosing the rotation by which side of
// the parent x hangs on.
void rotate (ns* x) {
	if (dir(x)) rotateL(x);
	else rotateR(x);
}

// Walks from x up to the root, calling upd() on every node along the way
// (x and the root included), and returns the root.
ns* getRoot (ns* x) {
	for (; x != NULL && x->p != NULL; x = x->p)
		x->upd();
	x->upd();
	return x;
}

// Restores heap order after cur was attached or changed: keeps rotating cur
// upward while its priority is strictly greater than its parent's, then
// refreshes aggregates up to the root and returns that root.
ns* fix (ns* cur) {
	for (;;) {
		ns* par = cur->p;
		if (par == NULL || cur->pr <= par->pr) break;
		cur->upd();
		rotate(cur);
	}
	return getRoot(cur);
}

// Number of nodes in the subtree rooted at x; a NULL subtree counts as 0.
// NOTE(review): the D() message fires for every NULL argument, although go()
// routinely passes NULL children here, so it is not necessarily an error.
int getSum(ns* x) {
	if (x != NULL) return x->sum;
	D("not good!\n");
	return 0;
}

// Inserts val below cur (normally the treap root). Equal values go to the
// right. The new leaf is then bubbled up by fix(), and the (possibly changed)
// root is returned.
ns* insert (ns* cur, int val) {
	for (;;) {
		ns*& child = (val >= cur->v) ? cur->r : cur->l;
		if (child == NULL) {
			ns* leaf = newS(val);
			child = leaf;
			leaf->p = cur;
			cur->upd();
			D("cur has %d elements\n", getSum(cur));
			return fix(leaf);
		}
		cur = child;
	}
}

// Recursively builds the segment-tree node for positions [s, e] of stuff.
// The node's treap h receives every stuff[i] in the range, and the range is
// then split at its midpoint into two children. Returns NULL for an empty
// range (s > e).
nr* build (int s, int e) {
	// Reachable e.g. when n == 0. Without this guard build(1, 0) computes m = 0
	// and calls build(1, 0) again forever. q() already treats a NULL node as
	// contributing nothing.
	if (s > e) return NULL;
	nr* cur = newR(s, e);
	cur->h = newS(stuff[s]);
	for (int i = s+1; i <= e; i++) {
		cur->h = insert(cur->h, stuff[i]);
	}
	D("building %d %d!\n", s, e);
	D("cur->h has %d elements!\n",getSum(cur->h));
	if (s == e) return cur;
	// Must match the split used by update() when it picks a child.
	int m = s + e;
	m /= 2;
	cur->l = build(s,m);
	cur->r = build(m+1, e);
	return cur;
}

// Returns some node with value v, descending by BST order from x. v must be
// present in the treap: there is no NULL check on the way down.
ns* find (ns* x, int v) {
	ns* at = x;
	while (at->v != v)
		at = (v < at->v) ? at->l : at->r;
	return at;
}

// Removes node x from its treap and returns the treap root.
// x is rotated down, always lifting its higher-priority child (ties lift the
// right child), until it is a leaf. It is then unlinked from its parent, and
// fix() refreshes aggregates from that parent up to the root. x->p is
// dereferenced once x is a leaf, so the treap must contain at least one other
// node. x's pool slot is not reclaimed.
ns* del (ns* x) {
	while (x->l != NULL || x->r != NULL) {
		ns* lift;
		if (x->r == NULL || (x->l != NULL && x->l->pr > x->r->pr)) lift = x->l;
		else lift = x->r;
		rotate(lift);
	}
	ns* par = x->p;
	if (par->l == x) par->l = NULL;
	else par->r = NULL;
	return fix(par);
}

// Replaces the value at position p with v in every segment-tree node on the
// path from x down to the leaf for p. It inserts v and then deletes one copy
// of stuff[p], so stuff[p] must still hold the OLD value during this call
// (main() assigns stuff[a] = b only afterwards).
void update (nr* x, int p, int v) {
	nr* node = x;
	for (;;) {
		D("updating %d %d in range (%d, %d) \n", p, v, node->s, node->e);
		node->h = insert(node->h, v);
		node->h = del(find(node->h, stuff[p]));
		if (node->s == node->e) return;
		int mid = (node->s + node->e) / 2;
		node = (p <= mid) ? node->l : node->r;
	}
}

// Returns the 1-based in-order rank of the last value <= k in the subtree
// rooted at x. [l, r] are the ranks that x's subtree occupies; qu() starts
// with [1, size].
// Precondition (ensured by qu()): the subtree holds at least one value <= k.
// That guarantees every recursive step below descends into a non-NULL child.
int go (ns* x, int l, int r, int k) {
	D("going %d %d %d\n", l, r, k);
	if (l == r) {
		if (x->v > k) return l-1;
		return l;
	}
	// The answer lies in the right subtree exactly when its minimum is <= k.
	// Check for the child explicitly. The former 2e9 sentinel for a missing
	// child sent go() into a NULL subtree whenever k >= 2e9.
	if (x->r != NULL && x->r->sml <= k) return go(x->r, l+getSum(x->l)+1, r, k);
	if (x->v <= k) return l+getSum(x->l);
	return go(x->l, l, r-getSum(x->r)-1, k);
}

// Counts the values in the treap rooted at x that are strictly greater than k.
// The count is the treap size minus the rank of the last value <= k; if even
// the minimum exceeds k, every value counts.
int qu (ns* x, int k) {
	const bool allAbove = x->sml > k;
	const int total = getSum(x);
	return allAbove ? total : total - go(x, 1, total, k);
}

// Number of positions in [s, e] whose value is greater than k. It adds up
// qu() over the segment-tree nodes fully covered by [s, e]; NULL nodes and
// nodes disjoint from the range contribute 0.
int q (nr* x, int s, int e, int k) {
	if (x == NULL || e < x->s || x->e < s) return 0;
	D("querying (%d,%d) %d %d %d\n", x->s, x->e, s, e, k);
	const bool covered = s <= x->s && x->e <= e;
	return covered ? qu(x->h, k) : q(x->l, s, e, k) + q(x->r, s, e, k);
}

// Pre-order dump of each segment-tree node's range and treap size. Output is
// produced only through D(), so it prints nothing unless DEBUG is defined.
void print(nr* x) {
	if (x != NULL) {
		D("range from (%d,%d) has %d elements!\n", x->s, x->e, getSum(x->h));
		print(x->l);
		print(x->r);
	}
}

int main() {
	srand(420);
	int n;
	scanf("%d",&n);
	fo(i,0,n) scanf("%d ", stuff+i+1);
	int numq;
	nr* root = build(1,n);
	D("finished building!\n");
	scanf("%d",&numq);
	fo(i,0,numq) {
		int a, b, c;
		scanf("%d",&a);
		if (a == 0) {
			scanf("%d %d", &a, &b);
			update(root, a, b);
			stuff[a] = b;
		} else {
			scanf("%d %d %d", &a, &b, &c);
			printf("%d\n", q(root, a, b, c));
		}
	}
	return 0;
}


/*
Create a new paste based on this one


Comments:
*/