# itertools.intersect?

Discussion in 'Python' started by David Wilson, Jun 10, 2009.

1. ### David Wilson

Hi,

During a fun coding session yesterday, I came across a problem that I
thought was already solved by itertools, but on investigation it seems
it isn't.

The problem is simple: given one or more ordered sequences, return
only the objects that appear in each sequence, without reading the
whole set into memory. This is basically an SQL many-many join.

I thought it could be accomplished through recursively embedded
generators, but that approach failed in the end. After posting the
question to Stack Overflow[0], Martin Geisler proposed a wonderfully
succinct and reusable solution (see below, or pretty printed at the
Stack Overflow URL).

It is my opinion that this particular implementation is a wonderful
and incredibly valid use of iterators, and something that could be
reused by others, not least myself again in the future. With that in
mind I thought it, or something very similar, would be a great addition
to itertools.

My question then is, are there better approaches to this? The heapq-
based solution at the Stack Overflow page is potentially more useful
still, for its ability to operate on orderless sequences, but in that
case, it might be better to simply listify each sequence, and sort it
before passing to the ordered-only functions.
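Here is a minimal Python 3 sketch of the sort-first route. It assumes the inputs fit in memory and hold hashable values. `intersect_unordered` and `intersect_two` are hypothetical names. For the ordered-only step, this sketch chains a two-way, merge-style intersection across the inputs with `functools.reduce`. That chain stays lazy because each stage is a generator. Duplicates are dropped with `set()` before sorting, so each common value appears once.

```python
from functools import reduce


def intersect_two(left, right):
    """Lazily yield values common to two ascending, duplicate-free iterables."""
    left, right = iter(left), iter(right)
    try:
        a, b = next(left), next(right)
        while True:
            if a < b:
                a = next(left)      # left is behind: advance it
            elif b < a:
                b = next(right)     # right is behind: advance it
            else:
                yield a             # both sides agree
                a, b = next(left), next(right)
    except StopIteration:
        # Python 3.7+ (PEP 479): finish explicitly when either side runs out.
        return


def intersect_unordered(sequences):
    """Sort (and de-duplicate) each input, then chain pairwise intersections."""
    prepared = [sorted(set(seq)) for seq in sequences]
    if not prepared:
        return iter(())
    return iter(reduce(intersect_two, prepared))


print(list(intersect_unordered([[322, 1, 100], [100, 2, 322], [956, 322, 100]])))
```

Only the sorting step needs the whole input in memory. The intersection itself never holds more than one current value per input.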

Thanks,

David.

Stack Overflow page here:

http://stackoverflow.com/questions/969709/joining-a-set-of-ordered-integer-yielding-python-iterators

Sweet solution:

import operator

def intersect(sequences):
    """Compute intersection of sequences of increasing integers.

    >>> list(intersect([[1, 100, 142, 322, 12312],
    ...                 [2, 100, 101, 322, 1221],
    ...                 [100, 142, 322, 956, 1222]]))
    [100, 322]

    """
    iterators = [iter(seq) for seq in sequences]
    last = [iterator.next() for iterator in iterators]
    indices = range(len(iterators))
    while True:
        # The while loop stops when StopIteration is raised. The
        # exception will also stop the iteration by our caller.
        if reduce(operator.and_, [l == last[0] for l in last]):
            # All iterators contain last[0]
            yield last[0]
            last = [iterator.next() for iterator in iterators]

        # Now go over the iterators once and advance them as
        # necessary. To stop as soon as the smallest iterator is
        # exhausted, we advance each iterator only once per loop
        # iteration.
        for i in indices[:-1]:
            if last[i] < last[i+1]:
                last[i] = iterators[i].next()
            if last[i] > last[i+1]:
                last[i+1] = iterators[i+1].next()

David Wilson, Jun 10, 2009

2. ### Jack Diederich

[snip]

Here's my version: keep a list of (curr_val, iterator) tuples and
operate on those.

def intersect(seqs):
    iter_pairs = [(it.next(), it) for it in map(iter, seqs)]
    while True:
        min_val = min(iter_pairs)[0]
        max_val = max(iter_pairs)[0]
        if min_val == max_val:
            yield min_val
            max_val += 1 # everybody advances
        for i, (val, it) in enumerate(iter_pairs):
            if val < max_val:
                iter_pairs[i] = (it.next(), it)
    # end while True

Interestingly, you don't need to explicitly catch StopIteration and
return, because only the top level is a generator. So both lines where
it.next() is called will potentially end the loop.

I also tried using a defaultdict(list) as the main structure
({curr_val: [it1, it2, ...]}); it worked but was uglier by far, with
dels and appends.
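The implicit-StopIteration trick depends on Python 2 semantics. Since Python 3.7 (PEP 479), a StopIteration that escapes a generator body is converted into RuntimeError, so a Python 3 version has to end the generator explicitly. The sketch below keeps the same (current value, iterator) idea with `next(it)` in place of `it.next()`. It compares values only, because in Python 3 two tuples whose first items tie fall through to comparing the iterators, which raises TypeError. `leaky_first` is a hypothetical anti-example added only to show the RuntimeError.

```python
def intersect(seqs):
    """Python 3 sketch of the (current value, iterator) approach."""
    try:
        pairs = [(next(it), it) for it in map(iter, seqs)]
        while pairs:
            values = [val for val, _ in pairs]
            min_val, max_val = min(values), max(values)
            if min_val == max_val:
                yield min_val
                # Everybody advances past the value just yielded.
                pairs = [(next(it), it) for _, it in pairs]
            else:
                # Only the iterators behind the largest value advance.
                pairs = [(next(it), it) if val < max_val else (val, it)
                         for val, it in pairs]
    except StopIteration:
        # Under PEP 479 this must be caught here; letting it escape the
        # generator would surface as RuntimeError in the caller.
        return


def leaky_first(seq):
    # Anti-example: next() on an exhausted iterator inside a generator.
    yield next(iter(seq))
```

With this version, `list(intersect([[1, 2], []]))` simply returns `[]`. By contrast, `list(leaky_first([]))` raises RuntimeError on Python 3.7 and later.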

-Jack

ps, woops, I forgot to hit reply all the first time.

Jack Diederich, Jun 11, 2009

3. ### Mensanator

Why not use SQL?

import sqlite3
con = sqlite3.connect(":memory:")
cur = con.cursor()
cur.executescript("""
create table test1(p INTEGER);
""")
cur.executescript("""
create table test2(q INTEGER);
""")
cur.executescript("""
create table test3(r INTEGER);
""")

for t in ((1,),(100,),(142,),(322,),(12312,)):
    cur.execute('insert into test1 values (?)', t)
for t in ((2,),(100,),(101,),(322,),(1221,)):
    cur.execute('insert into test2 values (?)', t)
for t in ((100,),(142,),(322,),(956,),(1222,)):
    cur.execute('insert into test3 values (?)', t)

cur.execute("""
SELECT p
FROM (test1 INNER JOIN test2 ON p = q)
INNER JOIN test3 ON p = r;
""")

sqlintersect = cur.fetchall()

for i in sqlintersect:
    print i[0],

print

##
## 100 322
##

Mensanator, Jun 11, 2009
4. ### Chris Rebert

Agreed. I seem to recall the last person asking for such a function
wanted to use it to combine SQL results.

Cheers,
Chris

Chris Rebert, Jun 11, 2009
5. ### David M. Wilson

My original use case was a full-text indexer. Here's the code:

Let me invert the question and ask: why would I want to use SQL for
this? Or in my own words: what kind of girly-code requires an RDBMS
just to join some sequences? =)

Google reports 14.132 billion occurrences of "the" on the English web,
which is about right given that some estimate the English web at ~15
billion documents. That needs about 33.8 bits to uniquely identify each
document, so let's assume a 64-bit integer per ID: 15 billion IDs at 8
bytes each is theoretically 111.7GiB of data loaded into SQL just for a
single word.
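The estimate can be checked with a few lines of arithmetic. The 111.7 GiB figure matches about 15 billion 8-byte IDs (roughly 111.76 GiB). The 14.132 billion occurrences of "the" on their own would come to roughly 105 GiB. The variable names below are illustrative only.

```python
import math

documents = 15e9          # estimated size of the English web
occurrences = 14.132e9    # Google's reported count for "the"
id_bytes = 8              # one 64-bit integer per document ID
GIB = 2 ** 30             # bytes per gibibyte

bits_per_id = math.log2(documents)        # minimum bits to tell documents apart
full_web_gib = documents * id_bytes / GIB
the_gib = occurrences * id_bytes / GIB

print(f"{bits_per_id:.1f} bits/ID, {full_web_gib:.2f} GiB, {the_gib:.2f} GiB")
```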

Introducing SQL quickly results in artificially requiring a database
system, when a 15-line function would have sufficed. It also restricts
how I store my data, and prevents, say, using a columnar, variable-length,
or delta encoding on my sequence of document IDs, which would
massively improve the storage footprint (say, saving 48-56 bits per
element). I'll avoid mention of the performance aspects altogether.
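Here is a minimal Python 3 sketch of delta plus variable-length encoding for an ascending list of document IDs. It is not taken from David's indexer. The varint format is LEB128-style, with 7 payload bits per byte and the high bit meaning "more bytes follow". `encode_postings` and `decode_postings` are hypothetical names. The decoder is a generator, so it could feed a streaming intersection without rebuilding the full list.

```python
def encode_postings(doc_ids):
    """Delta-encode ascending IDs, then varint-encode each gap."""
    out = bytearray()
    previous = 0
    for doc_id in doc_ids:
        gap = doc_id - previous          # assumes IDs never decrease
        previous = doc_id
        # Emit 7 bits at a time, low bits first; high bit set means "more".
        while gap >= 0x80:
            out.append((gap & 0x7F) | 0x80)
            gap >>= 7
        out.append(gap)
    return bytes(out)


def decode_postings(data):
    """Lazily yield the original ascending IDs from encode_postings output."""
    current = 0
    gap = 0
    shift = 0
    for byte in data:
        gap |= (byte & 0x7F) << shift
        if byte & 0x80:
            shift += 7                   # continuation byte: keep accumulating
        else:
            current += gap               # gap complete: undo the delta
            yield current
            gap = 0
            shift = 0
```

For `[1, 100, 142, 322, 12312]` the gaps are 1, 99, 42, 180 and 11990. These take 7 bytes in total, instead of 40 bytes as raw 64-bit integers.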

"What the hell are you thinking",

David

David M. Wilson, Jun 11, 2009
6. ### David M. Wilson

This version is a lot easier to understand. The implicit StopIteration
is a double-edged sword for readability, but I like it.

David

David M. Wilson, Jun 11, 2009
7. ### David M. Wilson

I found my answer: Python 2.6 introduces heapq.merge(), which is
designed exactly for this.

Thanks all,

David.

David M. Wilson, Jun 11, 2009
8. ### Carl Banks

Well, if the source data is already in an SQL database, that would make
the most sense; but if it isn't, and since the iterator is pretty
simple, I'd say just go with that.

Unless you have some other things happening downstream that would also
benefit from the source data being in a database, or something.

Carl Banks

Carl Banks, Jun 11, 2009
9. ### Jack Diederich

Thanks, I knew Raymond added something like that but I couldn't find
it in itertools.
That said, it doesn't help. As an aside, heapq.merge fits better in
itertools (it uses heaps internally but doesn't require them to be
passed in). The other function that almost helps is
itertools.groupby() and it doesn't return an iterator so is an odd fit
for itertools.

More specifically (and less curmudgeonly), heapq.merge doesn't help for
this particular case because you can't tell where the merged values
came from. You want all the iterators to yield the same thing at once
but heapq.merge muddles them all together (but in an orderly way!).
Unless I'm reading your tokenizer func wrong it can yield the same
value many times in a row. If that happens you don't know if four
"The"s are once each from four iterators or four times from one.
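Here is a small Python 3 illustration of that point, using only the standard library. `heapq.merge` yields one sorted stream with no record of which input each value came from. As a result, counting equal neighbours is fooled by duplicates inside a single input. One way around that, assumed here, is to deduplicate each input first with `itertools.groupby`. The helper names `dedupe` and `intersect_by_count` and the `remove_duplicates` flag are illustrative, not library API.

```python
import heapq
from itertools import groupby


def dedupe(sorted_iterable):
    """Collapse runs of equal values in an already sorted iterable."""
    return (key for key, _ in groupby(sorted_iterable))


def intersect_by_count(inputs, remove_duplicates=True):
    """Yield values whose run in the merged stream has one entry per input."""
    streams = [dedupe(s) if remove_duplicates else iter(s) for s in inputs]
    for value, run in groupby(heapq.merge(*streams)):
        if sum(1 for _ in run) == len(streams):
            yield value


# 1 appears twice in the first input but never in the second.
inputs = [[1, 1, 5], [2, 5]]
print(list(heapq.merge(*inputs)))                                   # [1, 1, 2, 5, 5]
print(list(intersect_by_count(inputs, remove_duplicates=False)))    # [1, 5]
print(list(intersect_by_count(inputs)))                             # [5]
```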

All that said, your problem is an edge case, so I'm happy to say the
ten-line composite functions that we've been trading can do what you want
to do and in clear prose. The stdlib isn't meant to have a one liner
for everything.

-Jack

Jack Diederich, Jun 11, 2009
10. ### Terry Reedy

David is looking to intersect sorted lists of document numbers, with
duplicates removed, in order to find documents that contain worda and
wordb and wordc, and so on. But you are right that duplicates are a
possible fly in the ointment, to be removed before merging.

Terry Reedy, Jun 11, 2009
11. ### Arnaud Delobelle

As it is a nice little problem, I tried to find a solution today. FWIW,
here it is (tested extensively only on the example below):

def intersect(iterables):
    nexts = [iter(iterable).next for iterable in iterables]
    v = [next() for next in nexts]
    while True:
        for i in xrange(1, len(v)):
            while v[0] > v[i]:
                v[i] = nexts[i]()
            if v[0] < v[i]: break
        else:
            yield v[0]
        v[0] = nexts[0]()

>>> list(intersect([[1, 100, 142, 322, 12312],
...                 [2, 100, 101, 322, 1221],
...                 [100, 142, 322, 956, 1222]]))
[100, 322]

Arnaud Delobelle, Jun 11, 2009
12. ### Jack Diederich

Ah, in that case the heapq.merge solution is both useful and succinct:

import heapq
import itertools
def intersect(its):
    source = heapq.merge(*its)
    # groupby() yields each run of equal values from the merged stream;
    # a value present in every (duplicate-free) input forms a run of
    # exactly len(its) items.
    for val, run in itertools.groupby(source):
        if len(list(run)) == len(its):
            yield val

-Jack

Jack Diederich, Jun 11, 2009
13. ### Mensanator

Removing the duplicates could be a big problem.

With SQL, the duplicates need not be removed.
All I have to do is change "SELECT" to "SELECT DISTINCT"
to change

100 100 100 322 322 322 322 322 322 322 322

into

100 322
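Here is a small Python 3 sketch of that change. It assumes an in-memory SQLite database with made-up tables that contain duplicate IDs, and uses the same three-way join written without the parentheses. An `ORDER BY` is added so that the output order is deterministic, because SQL does not otherwise guarantee one.

```python
import sqlite3

con = sqlite3.connect(":memory:")
cur = con.cursor()
cur.executescript("""
    CREATE TABLE test1(p INTEGER);
    CREATE TABLE test2(q INTEGER);
    CREATE TABLE test3(r INTEGER);
""")
# Duplicates within a table multiply the rows produced by the join.
cur.executemany("INSERT INTO test1 VALUES (?)", [(1,), (100,), (100,), (322,)])
cur.executemany("INSERT INTO test2 VALUES (?)", [(100,), (322,), (322,)])
cur.executemany("INSERT INTO test3 VALUES (?)", [(100,), (322,)])

join = """
    SELECT {distinct} p
    FROM test1 INNER JOIN test2 ON p = q
    INNER JOIN test3 ON p = r
    ORDER BY p
"""
with_duplicates = [row[0] for row in cur.execute(join.format(distinct=""))]
without_duplicates = [row[0] for row in cur.execute(join.format(distinct="DISTINCT"))]
print(with_duplicates)      # [100, 100, 322, 322]
print(without_duplicates)   # [100, 322]
```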

Mensanator, Jun 11, 2009
14. ### Raymond Hettinger

[David Wilson]
FWIW, this is equivalent to the Welfare Crook problem in David Gries's
book, The Science of Programming, http://tinyurl.com/mzoqk4.

Translated into Python, David Gries' solution looks like this:

from random import sample

def intersect(f, g, h):
    i = j = k = 0
    try:
        while True:
            if f[i] < g[j]:
                i += 1
            elif g[j] < h[k]:
                j += 1
            elif h[k] < f[i]:
                k += 1
            else:
                print(f[i])
                i += 1
    except IndexError:
        pass

streams = [sorted(sample(range(50), 30)) for i in range(3)]
for s in streams:
    print(s)
intersect(*streams)

Raymond

Raymond Hettinger, Jun 15, 2009
15. ### Andrew Henshaw

Here's my translation of your code to support variable number of streams:

def intersect(*s):
    num_streams = len(s)
    indices = [0]*num_streams
    try:
        while True:
            for i in range(num_streams):
                j = (i + 1) % num_streams
                if s[i][indices[i]] < s[j][indices[j]]:
                    indices[i] += 1
                    break
            else:
                print(s[0][indices[0]])
                indices[0] += 1
    except IndexError:
        pass
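The same cyclic-index scheme can also be written as a generator that yields matches instead of printing them. That makes it composable with `list()` or further iterator processing. This is a Python 3 sketch that assumes every stream is an indexable, ascending sequence such as a list. `intersect_indexed` is a hypothetical name.

```python
def intersect_indexed(*streams):
    """Yield values common to all ascending, indexable streams."""
    n = len(streams)
    if n == 0:
        return
    positions = [0] * n
    try:
        while True:
            for i in range(n):
                j = (i + 1) % n
                # If stream i lags behind its cyclic neighbour, advance it.
                if streams[i][positions[i]] < streams[j][positions[j]]:
                    positions[i] += 1
                    break
            else:
                # No stream lags its neighbour anywhere around the cycle,
                # so every current value is equal.
                yield streams[0][positions[0]]
                positions[0] += 1
    except IndexError:
        # Some stream ran off its end: no further common values exist.
        return


print(list(intersect_indexed([1, 100, 142, 322, 12312],
                             [2, 100, 101, 322, 1221],
                             [100, 142, 322, 956, 1222])))
```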

Andrew Henshaw, Jun 15, 2009
16. ### Arnaud Delobelle

I posted this solution earlier on:

def intersect(iterables):
    nexts = [iter(iterable).next for iterable in iterables]
    v = [next() for next in nexts]
    while True:
        for i in xrange(1, len(v)):
            while v[0] > v[i]:
                v[i] = nexts[i]()
            if v[0] < v[i]: break
        else:
            yield v[0]
        v[0] = nexts[0]()

It's quite similar but not as clever as the solution proposed by
R. Hettinger insofar as it doesn't exploit the fact that if a, b, c are
members of a totally ordered set, then:

if a >= b >= c >= a then a = b = c.
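The step from the cyclic chain to equality uses only transitivity and antisymmetry of the order. This is also what the for/else in the modified version relies on, because no cyclic neighbour pair satisfies v[i] < v[i+1]. Here it is spelled out for three values; the same argument extends around a cycle of any length.

```latex
\begin{aligned}
&a \ge b,\quad b \ge c,\quad c \ge a \\
&b \ge c \;\text{and}\; c \ge a \;\Rightarrow\; b \ge a \quad\text{(transitivity)} \\
&a \ge b \;\text{and}\; b \ge a \;\Rightarrow\; a = b \quad\text{(antisymmetry)} \\
&c \ge a = b \ge c \;\Rightarrow\; a = b = c
\end{aligned}
```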

However it can be easily modified to do so:

def intersect(iterables):
    nexts = [iter(iterable).next for iterable in iterables]
    v = [next() for next in nexts]
    while True:
        for i in xrange(-1, len(v)-1):
            if v[i] < v[i+1]:
                v[i] = nexts[i]()
                break
        else:
            yield v[0]
            v[0] = nexts[0]()

I haven't really thought about it too much, but there may be cases where
the original version terminates faster (I guess when it is expected that
the intersection is empty).

Arnaud Delobelle, Jun 15, 2009
17. ### Francis Carr

It is fairly easy to ignore duplicates in a sorted list:
<pre>
from itertools import groupby
def unique(ordered):
    """Yield the unique elements from a sorted iterable.
    """
    for key,_ in groupby(ordered):
        yield key
</pre>

Combining this with some ideas of others, we have a simple, complete
solution:
<pre>
def intersect(*ascendingSeqs):
    """Yield the intersection of zero or more ascending iterables.
    """
    N=len(ascendingSeqs)
    if N==0:
        return

    unq = [unique(s) for s in ascendingSeqs]
    val = [u.next() for u in unq]
    while True:
        for i in range(N):
            while val[i-1] > val[i]:
                val[i] = unq[i].next()
        if val[0]==val[-1]:
            yield val[0]
            val[-1] = unq[-1].next()
</pre>

This works with empty arg-lists; combinations of empty, infinite and
finite iterators; iterators with duplicate elements; etc. The only
requirement is that all iterators are sorted ascending.
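Here is a Python 3 port of the two functions, as a sketch. `u.next()` becomes `next(u)`, and because of PEP 479 (Python 3.7+) the StopIteration from an exhausted input has to be caught and turned into a plain `return`. The infinite-iterator case can be exercised with `itertools.count`. One caveat is that two infinite inputs with no further common value would loop forever without yielding.

```python
from itertools import count, groupby


def unique(ordered):
    """Yield the unique elements from a sorted iterable."""
    for key, _ in groupby(ordered):
        yield key


def intersect(*ascending_seqs):
    """Yield the intersection of zero or more ascending iterables."""
    n = len(ascending_seqs)
    if n == 0:
        return
    unq = [unique(s) for s in ascending_seqs]
    try:
        val = [next(u) for u in unq]
        while True:
            # Raise each value to at least its (cyclic) predecessor,
            # leaving val[0] <= val[1] <= ... <= val[-1].
            for i in range(n):
                while val[i - 1] > val[i]:
                    val[i] = next(unq[i])
            if val[0] == val[-1]:
                yield val[0]
                val[-1] = next(unq[-1])
    except StopIteration:
        # Some input is exhausted, so no further common values exist.
        return


# Infinite and finite inputs can be mixed; duplicates are ignored.
print(list(intersect(count(), [2, 5, 5, 9])))   # [2, 5, 9]
```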

-- FC

Francis Carr, Jun 22, 2009