To count the number of quadruplets with sum = 0

n00m

http://www.spoj.pl/problems/SUMFOUR/

3
0 0 0 0
0 0 0 0
-1 -1 1 1
Answer for this input data is 33.

My solution for the problem is
======================================================================

import time
t = time.clock()

q,w,e,r,sch,h = [],[],[],[],0,{}

f = open("D:/m4000.txt","rt")

n = int(f.readline())

for o in range(n):
    row = map(long, f.readline().split())
    q.append(row[0])
    w.append(row[1])
    e.append(row[2])
    r.append(row[3])

f.close()

for x in q:
    for y in w:
        if h.has_key(x+y):
            h[x+y] += 1
        else:
            h[x+y] = 1

for x in e:
    for y in r:
        sch += h.get(-(x+y),0)

q,w,e,r,h = None,None,None,None,None

print sch
print time.clock() - t

===============================================================

Alas, it gets "time limit exceeded".
On my home PC (1.6 GHz, 512 MB RAM) it takes ~1.5 min for 4000 input
rows.
Any ideas to speed it up, say, 10 times? Or is the problem only for
C-like langs?
 
Matimus

Any ideas to speed it up say 10 times? Or the problem only for C-like

I doubt this will speed it up by a factor of 10, but in your solution
you are mapping the values in the input file to longs. The problem
statement says that the maximum value for any of the numbers is
2**28. I assume the reason for this is precisely to allow you to use
32-bit integers. So you can safely use int instead, and it _might_
speed things up a little.
 
Steven Bethard

n00m said:
http://www.spoj.pl/problems/SUMFOUR/

3
0 0 0 0
0 0 0 0
-1 -1 1 1
Answer for this input data is 33.

My solution for the problem is
======================================================================

import time
t = time.clock()

q,w,e,r,sch,h = [],[],[],[],0,{}

f = open("D:/m4000.txt","rt")

n = int(f.readline())

for o in range(n):
    row = map(long, f.readline().split())
    q.append(row[0])
    w.append(row[1])
    e.append(row[2])
    r.append(row[3])

f.close()

for x in q:
    for y in w:
        if h.has_key(x+y):
            h[x+y] += 1
        else:
            h[x+y] = 1

This won't help much, but you can rewrite the above as::

x_y = x + y
h[x_y] = h.get(x_y, 0) + 1

Or if you're using Python 2.5, try::

h = collections.defaultdict(itertools.repeat(0).next)

...
for x in q:
    for y in w:
        h[x + y] += 1
...

Not likely to get you an order of magnitude though.
for x in e:
    for y in r:
        sch += h.get(-(x+y),0)

If you use the collections.defaultdict approach above, this becomes::

for x in e:
    for y in r:
        sch += h[-(x + y)]

Note that you should also probably put all your code into a function --
looking up function locals is quicker than looking up module globals.
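For comparison, the same meet-in-the-middle algorithm fits in a few lines of modern Python 3 with `collections.Counter` (a sketch, not anyone's submitted solution; `rows` is assumed to be a list of 4-tuples):

```python
from collections import Counter

def sumfour(rows):
    """Count quadruples (a, b, c, d), one value per column, with a+b+c+d == 0."""
    A = [r[0] for r in rows]
    B = [r[1] for r in rows]
    C = [r[2] for r in rows]
    D = [r[3] for r in rows]
    # Count every pairwise sum a+b, then look up -(c+d) in that table.
    ab = Counter(a + b for a in A for b in B)       # O(n^2) time and space
    return sum(ab[-(c + d)] for c in C for d in D)  # Counter misses are 0

rows = [(0, 0, 0, 0), (0, 0, 0, 0), (-1, -1, 1, 1)]
print(sumfour(rows))  # 33, matching the sample answer above
```

Note that `Counter`, unlike `defaultdict`, does not insert an entry on a failed lookup, so the second pass cannot blow up the table.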

STeVe
 
Paul Rubin

n00m said:
http://www.spoj.pl/problems/SUMFOUR/
3
0 0 0 0
0 0 0 0
-1 -1 1 1
Answer for this input data is 33.

f = open('input1')
npairs = int(f.readline())

quads = [map(int, f.readline().split()) for i in xrange(npairs)]
assert len(quads) == npairs

da = {}

for p in quads:
    for q in quads:
        z = p[2] + q[3]
        da[z] = da.get(z,0) + 1

print sum([da.get(-(p[0]+q[1]), 0) for p in quads for q in quads])
 
Paul Rubin

Paul Rubin said:
print sum([da.get(-(p[0]+q[1]), 0) for p in quads for q in quads])

The above should say:

print sum(da.get(-(p[0]+q[1]), 0) for p in quads for q in quads)

I had to use a listcomp instead of a genexp for testing, since I'm
still using python 2.3, and I forgot to patch that before pasting to
the newsgroup. But the listcomp will burn a lot of memory when the
list is large.

I'd be interested in the running time of the above for your large data set.
Note that it still uses potentially O(n**2) space.
 
Michael Spencer

n00m said:
http://www.spoj.pl/problems/SUMFOUR/

3
0 0 0 0
0 0 0 0
-1 -1 1 1
Answer for this input data is 33.

My solution for the problem is
======================================================================

import time
t = time.clock()

q,w,e,r,sch,h = [],[],[],[],0,{}

f = open("D:/m4000.txt","rt")

n = int(f.readline())

for o in range(n):
    row = map(long, f.readline().split())
    q.append(row[0])
    w.append(row[1])
    e.append(row[2])
    r.append(row[3])

f.close()

for x in q:
    for y in w:
        if h.has_key(x+y):
            h[x+y] += 1
        else:
            h[x+y] = 1

for x in e:
    for y in r:
        sch += h.get(-(x+y),0)

q,w,e,r,h = None,None,None,None,None

print sch
print time.clock() - t

===============================================================

Alas, it gets "time limit exceeded".
On my home PC (1.6 GHz, 512 MB RAM) it takes ~1.5 min for 4000 input
rows.
Any ideas to speed it up, say, 10 times? Or is the problem only for
C-like langs?
Perhaps a bit faster using slicing to get the lists and avoiding dict.get:

def sumfour(src):
    l = map(int, src.split())
    dct = {}
    s = 0
    A, B, C, D = l[1::4], l[2::4], l[3::4], l[4::4]
    for a in A:
        for b in B:
            if a+b in dct:
                dct[a+b] += 1
            else:
                dct[a+b] = 1
    for c in C:
        for d in D:
            if -c-d in dct:
                s += dct[-c-d]
    return s


if __name__ == '__main__':
    import sys
    print sumfour(sys.stdin.read())
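The stride slicing splits the flat token stream (the leading count followed by the quadruples) into four columns in one pass; a small Python 3 illustration of the idiom:

```python
# Flat token stream: n, then n quadruples, exactly as in the problem input.
tokens = "3  0 0 0 0  0 0 0 0  -1 -1 1 1"
l = list(map(int, tokens.split()))

# l[0] is the count; every 4th token from offsets 1..4 is one column.
A, B, C, D = l[1::4], l[2::4], l[3::4], l[4::4]
print(A, B, C, D)  # [0, 0, -1] [0, 0, -1] [0, 0, 1] [0, 0, 1]
```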


Michael
 
n00m

Steven, I ran this:

import time, collections, itertools
t = time.clock()

q,w,e,r,sch = [],[],[],[],0

h = collections.defaultdict(itertools.repeat(0).next)

f = open("D:/m4000.txt","rt")

for o in range(int(f.readline())):
    row = map(int, f.readline().split())
    q.append(row[0])
    w.append(row[1])
    e.append(row[2])
    r.append(row[3])

f.close()

for x in q:
    for y in w:
        h[x+y] += 1

for x in e:
    for y in r:
        sch += h[-(x + y)]

q,w,e,r,h = None,None,None,None,None

print sch
print time.clock() - t


========= and it almost froze my PC...
=== but it was faster than my code on an input file with 1000 rows:
====== 2.00864607094 s vs 3.14631077413 s
 
n00m

Paul,

import time
t = time.clock()
f = open("D:/m4000.txt","rt")
npairs = int(f.readline())
quads = [map(int, f.readline().split()) for i in xrange(npairs)]
f.close()
da = {}
for p in quads:
    for q in quads:
        z = p[2] + q[3]
        da[z] = da.get(z,0) + 1
print sum(da.get(-(p[0]+q[1]), 0) for p in quads for q in quads)
print time.clock() - t


The first two outputs are from the above (your) code; the next two from my code:
0
68.9562762865
0
68.0813539151
0
62.5012896891
0
62.5030784639
 
Paul Rubin

n00m said:
h = collections.defaultdict(itertools.repeat(0).next)

Something wrong with
h = collections.defaultdict(int)
?????
for x in e:
    for y in r:
        sch += h[-(x + y)]

That scares me a little: I think it makes a new entry in h, for the
cases where -(x+y) is not already in h. You want:

for x in e:
    for y in r:
        sch += h.get(-(x+y), 0)
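Indeed, a plain subscript lookup on a `defaultdict` inserts the missing key as a side effect, while `.get()` does not; a small Python 3 demonstration:

```python
from collections import defaultdict

h = defaultdict(int)
h[1] += 1                       # one real entry
_ = h[999]                      # mere lookup of a missing key...
print(len(h))                   # 2 -- the miss was inserted as h[999] == 0
print(h.get(12345, 0), len(h))  # 0 2 -- .get() does not insert
```

With n^2 lookups in the second pass, those phantom zero entries can consume far more memory than the real sums.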
 
Paul Rubin

n00m said:
Two first outputs is of above (your) code; next two - of my code:

Yeah, I see now that we both used the same algorithm. At first glance
I thought you had done something much slower. The 10 second limit
they gave looks like they intended to do it about this way, but with a
compiled language. 68 seconds isn't so bad for 4000 entries in pure
CPython. Next thing to do I think is use psyco or pyrex.
 
n00m

Steve,
imo strangely enough but your suggestion to replace "if...: else:..."
with

x_y = x + y
h[x_y] = h.get(x_y, 0) + 1

s=l=o=w=e=d the thing by ~1 sec.
 
Paul Rubin

n00m said:

I get 33190970 for the first set and 0 for the second set.

The first set only makes 38853 distinct dictionary entries, I guess
because the numbers are all fairly small so there's a lot of duplication.
The second set makes 8246860 distinct entries.

I got a 128.42 sec runtime on the second set, on a 1.2 GHz Pentium M.
That was with the listcomp in the summation, which burned a lot of
extra memory, but I didn't bother writing out a summation loop.

I guess I can see a few other possible micro-optimizations but no
obvious algorithm improvements.
 
Paul McGuire

for o in range(int(f.readline())):
    row = map(int, f.readline().split())
    q.append(row[0])
    w.append(row[1])
    e.append(row[2])
    r.append(row[3])

Does this help at all in reading in your data?

numlines = int(f.readline())
rows = [ map(int, f.readline().split()) for _ in range(numlines) ]
q,w,e,r = zip(*rows)

-- Paul
 
Marc 'BlackJack' Rintsch

Something wrong with
h = collections.defaultdict(int)
?????

According to a post by Raymond Hettinger it's faster to use that iterator
instead of `int`.

Ciao,
Marc 'BlackJack' Rintsch
 
Anton Vredegoor

n00m said:
62.5030784639

Maybe this one could save a few seconds, it works best when there are
multiple occurrences of the same value.

A.

from time import time

def freq(L):
    D = {}
    for x in L:
        D[x] = D.get(x,0)+1
    return D

def test():
    t = time()
    f = file('m4000.txt')
    f.readline()
    L = []
    for line in f:
        L.append(map(int,line.split()))

    q,w,e,r = map(freq,zip(*L))
    sch,h = 0,{}
    for xk,xv in q.iteritems():
        for yk,yv in w.iteritems():
            if h.has_key(xk+yk):
                h[xk+yk] += xv*yv
            else:
                h[xk+yk] = xv*yv

    for xk,xv in e.iteritems():
        for yk,yv in r.iteritems():
            if h.has_key(-(xk+yk)):
                sch += h[-(xk+yk)]*xv*yv

    print sch
    print time()-t

if __name__=='__main__':
    test()
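In modern Python 3, Anton's duplicate-collapsing idea can be written more tersely with `collections.Counter` (a sketch over the sample data, not his posted code):

```python
from collections import Counter

def sumfour_freq(rows):
    # Collapse repeated column values into (value, count) pairs first;
    # the quadratic loops then run over distinct values only.
    q, w, e, r = (Counter(col) for col in zip(*rows))
    h = Counter()
    for xk, xv in q.items():
        for yk, yv in w.items():
            h[xk + yk] += xv * yv
    return sum(h[-(xk + yk)] * xv * yv
               for xk, xv in e.items() for yk, yv in r.items())

rows = [(0, 0, 0, 0), (0, 0, 0, 0), (-1, -1, 1, 1)]
print(sumfour_freq(rows))  # 33
```

The payoff depends entirely on the input: with many repeated values the distinct-value loops are far smaller than n^2; with all-distinct values it degenerates to the plain algorithm.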
 
Steven Bethard

Marc said:
According to a post by Raymond Hettinger it's faster to use that iterator
instead of `int`.

Yep. It's because the .next() method takes no arguments, while int()
takes varargs because you can do::

int('2')
int('2', 8)

Calling a no-args function is substantially faster than calling a
varargs function.
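In Python 3 the same trick is spelled with `__next__` rather than `.next`; a quick sketch of using it as a defaultdict factory:

```python
import itertools
from collections import defaultdict

# Python 3 spelling: the bound method is __next__, not .next
make_zero = itertools.repeat(0).__next__
print(make_zero())  # 0 -- a no-argument callable, no varargs parsing

d = defaultdict(make_zero)
d['x'] += 1
print(d['x'])       # 1
```

Whether it still beats `defaultdict(int)` on a current interpreter is worth re-measuring with `timeit` rather than assuming.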

STeVe
 
Paul Rubin

Steven Bethard said:
Yep. It's because the .next() method takes no arguments, while int()
takes varargs because you can do:: ...

Heh, good point. Might be worth putting a special hack in defaultdict
to recognize the common case of defaultdict(int).
 
marek.rocki

My attempt uses a different approach: create two sorted arrays of n^2
elements each, then iterate over them looking for matching
elements (only one pass is required). I managed to get 58.2250612857 s
on my 1.7 GHz machine. It requires numpy for decent performance,
though.

import numpy
import time

def parse_input():
    al, bl, cl, dl = [], [], [], []
    for i in xrange(int(raw_input())):
        a, b, c, d = map(int, raw_input().split())
        al.append(a)
        bl.append(b)
        cl.append(c)
        dl.append(d)
    return al, bl, cl, dl

def count_zero_sums(al, bl, cl, dl):
    n = len(al) # Assume others are equal

    # Construct al extended (every element is repeated n times)
    ale = numpy.array(al).repeat(n)
    del al
    # Construct bl extended (whole array is repeated n times)
    ble = numpy.zeros((n*n,), int)
    for i in xrange(n): ble[i*n:(i+1)*n] = bl
    del bl
    # Construct abl - sorted list of all sums of a, b for a, b in al, bl
    abl = numpy.sort(ale + ble)
    del ale, ble

    # Construct cl extended (every element is repeated n times)
    cle = numpy.array(cl).repeat(n)
    del cl
    # Construct dl extended (whole array is repeated n times)
    dle = numpy.zeros((n*n,), int)
    for i in xrange(n): dle[i*n:(i+1)*n] = dl
    del dl
    # Construct cdl - sorted list of all negated sums of c, d for c, d in cl, dl
    cdl = numpy.sort(-(cle + dle))
    del cle, dle

    # Iterate over arrays, count matching elements
    result = 0
    i, j = 0, 0
    n = n*n
    try:
        while True:
            while abl[i] < cdl[j]:
                i += 1
            while abl[i] > cdl[j]:
                j += 1
            if abl[i] == cdl[j]:
                # Found matching sequences
                ii = i + 1
                while ii < n and abl[ii] == abl[i]: ii += 1
                jj = j + 1
                while jj < n and cdl[jj] == cdl[j]: jj += 1
                result += (ii - i)*(jj - j)
                i, j = ii, jj
    except IndexError:
        pass

    return result

t = time.clock()
print count_zero_sums(*parse_input())
print time.clock() - t
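The same sort-and-merge counting can be sketched in plain Python 3 without numpy (slower, but easy to check against the sample input; names here are illustrative, not from the post):

```python
def count_matches(xs, ys):
    """Count pairs (i, j) with xs[i] == ys[j]; both lists must be sorted."""
    i = j = total = 0
    while i < len(xs) and j < len(ys):
        if xs[i] < ys[j]:
            i += 1
        elif xs[i] > ys[j]:
            j += 1
        else:
            # Found a matching value: count the run length on each side.
            v = xs[i]
            ii = i
            while ii < len(xs) and xs[ii] == v: ii += 1
            jj = j
            while jj < len(ys) and ys[jj] == v: jj += 1
            total += (ii - i) * (jj - j)
            i, j = ii, jj
    return total

rows = [(0, 0, 0, 0), (0, 0, 0, 0), (-1, -1, 1, 1)]
ab = sorted(a + b for a, _, _, _ in rows for _, b, _, _ in rows)
cd = sorted(-(c + d) for _, _, c, _ in rows for _, _, _, d in rows)
print(count_matches(ab, cd))  # 33
```

Run-length counting on both sides is what keeps the merge to a single pass even when values repeat.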
 
