Sunday, August 10, 2014

Optimizing Diversity Picking in the RDKit

The RDKit has an implementation of the MaxMin algorithm (Ashton, M. et al., Quant. Struct.-Act. Relat., 21 (2002), 598-604) for diversity picking. This provides a computationally efficient method for picking diverse subsets of molecules from a larger set.

This post compares the performance of the different ways of invoking the algorithm and takes a brief look at how the runtime scales with the size of the set and with the number of picks.

In [1]:
from __future__ import print_function
import numpy as np

from rdkit import Chem
from rdkit.Chem import Draw,rdMolDescriptors,AllChem
from rdkit.Chem.Draw import IPythonConsole
from rdkit import SimDivFilters,DataStructs

import gzip
%pylab inline

from rdkit import rdBase
print(rdBase.rdkitVersion)
import time
print(time.asctime())
Populating the interactive namespace from numpy and matplotlib
2014.09.1pre
Mon Aug 11 07:51:19 2014

Start by reading in a set of molecules from the Zinc Biogenic Compounds (formerly known as "Zinc Natural Products") subset and generating Morgan2 fingerprints for the diversity calculation. The compound structures are available here: http://zinc.docking.org/subsets/zbc

In [10]:
ms = []
inf = gzip.open('../data/znp.sdf.gz')
suppl = Chem.ForwardSDMolSupplier(inf)
while len(ms)<20000:
    m = next(suppl)
    if m is None:
        continue  # skip records the RDKit could not parse
    AllChem.Compute2DCoords(m)  # 2D coordinates for drawing the picks later
    ms.append(m)
In [4]:
fps = [rdMolDescriptors.GetMorganFingerprintAsBitVect(m,2) for m in ms]

There are two basic ways to use the MaxMinPicker:

  1. Generate the distance matrix in advance and do the picking based on those pre-calculated distances
  2. Use a "lazy" form of the picker that generates distances using a caller-supplied function. For reasons of efficiency, this uses a cache internally so that distances don't have to be calculated more than once.
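The caching idea behind the lazy picker can be sketched in a few lines of plain Python. This is only an illustration, assuming a symmetric distance function; the helper names (`make_cached`, `toy_distance`) and the set-based toy fingerprints are hypothetical, and the RDKit's own cache lives in the C++ code and differs in detail.

```python
def make_cached(dist_fn):
    """Wrap a symmetric distance function dist_fn(i, j) in a dict-based cache.

    Returns the cached function plus a dict counting real evaluations.
    Hypothetical helper, for illustration only.
    """
    cache = {}
    stats = {"computed": 0}

    def cached(i, j):
        key = (i, j) if i < j else (j, i)  # d(i, j) == d(j, i): store one copy
        if key not in cache:
            stats["computed"] += 1
            cache[key] = dist_fn(*key)
        return cache[key]

    return cached, stats


# Hypothetical toy "fingerprints": sets of on-bit indices
toy_fps = [{1, 2, 3}, {2, 3, 4}, {7, 8}, {1, 8, 9}]

def toy_distance(i, j):
    """Tanimoto distance between two sets of on bits."""
    a, b = toy_fps[i], toy_fps[j]
    return 1.0 - len(a & b) / len(a | b)

dist, stats = make_cached(toy_distance)
print(dist(0, 1), dist(1, 0), dist(0, 1))  # 0.5 0.5 0.5
print(stats["computed"])                   # 1: computed once, then served from the cache
```

The picker asks for the same pairs over and over as the picked set grows, which is why trading memory for recomputation pays off.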

Here's a demonstration of using the first approach to pick 50 diverse instances from the first 1000 compounds in the set:

In [5]:
def dmat_sim(fps,ntopick):
    # build the lower triangle of the distance matrix, row by row:
    # d(1,0), d(2,0), d(2,1), d(3,0), ...
    ds=[]
    for i in range(1,len(fps)):
        ds.extend(DataStructs.BulkTanimotoSimilarity(fps[i],fps[:i],returnDistance=True))
    mmp =SimDivFilters.MaxMinPicker()
    ids=mmp.Pick(np.array(ds),len(fps),ntopick)
    return ids

dmat_ids=dmat_sim(fps[:1000],50)

Note that there are some examples of this approach floating around on the web that build the distance matrix in the wrong order: Pick() expects the lower triangle of the matrix, laid out row by row as in dmat_sim() above. I've done my best to find and either fix or remove them, but there are no doubt still some bad ones out there.
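To make the expected order concrete: dmat_sim() stores only the lower triangle of the symmetric distance matrix, row by row, so the list starts d(1,0), d(2,0), d(2,1), d(3,0), ... The sketch below (plain Python, hypothetical helper names) builds a list in that order and computes where a given pair ends up; checking a hand-built list against condensed_index() is a quick way to catch the wrong order.

```python
def condensed_index(i, j):
    """Position of d(i, j) in the row-by-row lower-triangle list used by dmat_sim()."""
    if i == j:
        raise ValueError("the diagonal is not stored")
    if i < j:
        i, j = j, i  # only pairs with i > j are stored
    # rows 1..i-1 contribute 1 + 2 + ... + (i-1) = i*(i-1)/2 entries before row i
    return i * (i - 1) // 2 + j

def lower_triangle(dist, n):
    """Build the list in the same order as dmat_sim(): d(1,0), d(2,0), d(2,1), ..."""
    ds = []
    for i in range(1, n):
        for j in range(i):
            ds.append(dist(i, j))
    return ds

# toy symmetric "distance" that makes each pair easy to recognize
def toy_dist(i, j):
    return 10 * max(i, j) + min(i, j)

ds = lower_triangle(toy_dist, 5)
print(ds)                        # [10, 20, 21, 30, 31, 32, 40, 41, 42, 43]
print(ds[condensed_index(1, 3)])  # 31
```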

In [7]:
print(list(dmat_ids)[:10])
[374, 636, 790, 497, 724, 366, 418, 587, 433, 537]

In [11]:
Draw.MolsToGridImage([ms[x] for x in dmat_ids[:10]],molsPerRow=4)
Out[11]: [image: grid of the first 10 picked molecules]

Those are certainly diverse.

Now try the same thing using the lazy picker and a user-supplied distance function:

In [13]:
def fn(i,j,fps=fps):
    return 1.-DataStructs.TanimotoSimilarity(fps[i],fps[j])
mmp =SimDivFilters.MaxMinPicker()
lazy_ids = mmp.LazyPick(fn,1000,50)

Make sure the results are the same:

In [14]:
list(lazy_ids)==list(dmat_ids)
Out[14]:
True

When working with bit vectors, there is a third option, LazyBitVectorPick, which does the distance calculation directly in the C++ code. I'll show below that this makes a noticeable difference in performance.

Note that LazyBitVectorPick is a new addition to the RDKit: at the time of this writing it is only available in the GitHub version of the code. It will appear in the 2014.09 release.

In [15]:
mmp =SimDivFilters.MaxMinPicker()
bv_ids = mmp.LazyBitVectorPick(fps[:1000],1000,50)
In [16]:
list(bv_ids)==list(dmat_ids)
Out[16]:
True

Now let's look at the relative performance of these approaches for a subset of the data.

Timing data was generated on a not-particularly-modern machine: a four-year-old MacBook Pro.

In [18]:
mmp =SimDivFilters.MaxMinPicker()
%timeit dmat_sim(fps[:2000],50)
1 loops, best of 3: 1.77 s per loop

In [19]:
mmp =SimDivFilters.MaxMinPicker()
%timeit mmp.LazyPick(fn,2000,50)
1 loops, best of 3: 1.01 s per loop

In [23]:
mmp =SimDivFilters.MaxMinPicker()
%timeit mmp.LazyBitVectorPick(fps,2000,50)
1 loops, best of 3: 458 ms per loop

So we have a clear winner: LazyBitVectorPick is about twice as fast as LazyPick with a Python distance function, and almost four times as fast as building the distance matrix in advance.

Impact of the cache

LazyBitVectorPick is very efficient in terms of runtime, but the memory used by its cache (about nPoints * nPicks * 30 bytes) could in principle cause problems for large data sets. The cache can be disabled, at the expense of a substantially longer run time:

In [22]:
mmp =SimDivFilters.MaxMinPicker()
%timeit mmp.LazyBitVectorPick(fps,2000,50,useCache=False)
1 loops, best of 3: 1.37 s per loop
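The rule of thumb for the cache size (roughly 30 bytes per candidate/pick pair) is easy to turn into a quick estimate. This is back-of-the-envelope arithmetic based only on that figure, not a measurement; the real footprint depends on the implementation, and the one-million-compound case is hypothetical.

```python
def cache_bytes(n_points, n_picks, bytes_per_entry=30):
    """Approximate cache size from the ~30 bytes per (point, pick) rule of thumb."""
    return n_points * n_picks * bytes_per_entry

# the 2000-compound, 50-pick runs timed above
print(cache_bytes(2000, 50) / 1e6, "MB")       # 3.0 MB
# a hypothetical pick of 1000 compounds from one million
print(cache_bytes(1000000, 1000) / 1e9, "GB")  # 30.0 GB
```

A few megabytes is harmless; tens of gigabytes is the kind of case where useCache=False becomes worth its roughly threefold cost in run time seen above.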

Impact of data set size

Let's look at how the run time evolves with the number of compounds we're picking from:

In [36]:
for sz in range(1,11):
    sz*=1000
    print("Doing: %1d"%(sz))
    mmp =SimDivFilters.MaxMinPicker()
    %timeit mmp.LazyBitVectorPick(fps,sz,50)
Doing: 1000
10 loops, best of 3: 186 ms per loop
Doing: 2000
1 loops, best of 3: 417 ms per loop
Doing: 3000
1 loops, best of 3: 688 ms per loop
Doing: 4000
1 loops, best of 3: 955 ms per loop
Doing: 5000
1 loops, best of 3: 1.21 s per loop
Doing: 6000
1 loops, best of 3: 1.49 s per loop
Doing: 7000
1 loops, best of 3: 1.8 s per loop
Doing: 8000
1 loops, best of 3: 2.34 s per loop
Doing: 9000
1 loops, best of 3: 2.22 s per loop
Doing: 10000
1 loops, best of 3: 2.55 s per loop

The runtime increases more or less linearly with the size of the set being picked from.
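To put a number on "more or less linearly", here is a least-squares fit to the timings printed above, in plain Python. It is a quick sketch (the helper name fit_line is hypothetical): a straight-line fit gives the cost per additional 1000 compounds, and a fit in log-log space gives the scaling exponent, which would be exactly 1 for perfectly linear behavior.

```python
import math

# (pool size, best-of-3 time in seconds) copied from the output above
sizes = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000]
times = [0.186, 0.417, 0.688, 0.955, 1.21, 1.49, 1.8, 2.34, 2.22, 2.55]

def fit_line(xs, ys):
    """Ordinary least-squares fit of y = a*x + b; returns (a, b)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    a = sxy / sxx
    return a, my - a * mx

# slope of time vs. size, in seconds per extra 1000 compounds
slope, _ = fit_line([s / 1000 for s in sizes], times)
# slope of log(time) vs. log(size): the scaling exponent
exponent, _ = fit_line([math.log(s) for s in sizes], [math.log(t) for t in times])
print(f"{slope:.2f} s per 1000 compounds, exponent {exponent:.2f}")
```

With these numbers the slope comes out at about 0.27 s per 1000 compounds and the exponent slightly above 1. The 8000-compound point, slower than the 9000-compound one, is a reminder that individual timings are noisy.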

Impact of the number of points being picked

In [37]:
for sz in range(1,11):
    sz*=20
    print("Doing: %1d"%(sz))
    mmp =SimDivFilters.MaxMinPicker()
    %timeit mmp.LazyBitVectorPick(fps,5000,sz)
Doing: 20
1 loops, best of 3: 201 ms per loop
Doing: 40
1 loops, best of 3: 737 ms per loop
Doing: 60
1 loops, best of 3: 1.69 s per loop
Doing: 80
1 loops, best of 3: 3.03 s per loop
Doing: 100
1 loops, best of 3: 4.78 s per loop
Doing: 120
1 loops, best of 3: 6.91 s per loop
Doing: 140
1 loops, best of 3: 9.64 s per loop
Doing: 160
1 loops, best of 3: 12.4 s per loop
Doing: 180
1 loops, best of 3: 15.8 s per loop
Doing: 200
1 loops, best of 3: 20 s per loop

This time the runtime grows roughly quadratically with the number of picks: ten times as many picks takes about a hundred times as long. Given the algorithm, which needs to compare each candidate point against all of the points picked so far, this makes perfect sense.
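The scaling argument can be checked by counting distance evaluations in a small pure-Python MaxMin written the way the sentence above describes it. This is a sketch, not the RDKit implementation: the function and data are hypothetical, it uses toy 2D points with Euclidean distance instead of fingerprints, it simply starts from point 0, and it counts work rather than timing anything. With N candidates and k picks the count is the sum over t = 1..k-1 of (N - t)*t, roughly N*k²/2.

```python
import math
import random

def maxmin_pick(points, n_picks, dist):
    """Naive MaxMin sketch: in every round, each remaining candidate is compared
    against *all* points picked so far, and the candidate whose nearest picked
    point is farthest away is taken next.

    Starts from point 0 (a simplification). Returns (picks, n_distance_evals).
    """
    picks = [0]
    picked = {0}
    n_evals = 0
    while len(picks) < n_picks:
        best, best_d = None, -1.0
        for c in range(len(points)):
            if c in picked:
                continue
            # distance from candidate c to its nearest already-picked point
            d_min = min(dist(points[c], points[p]) for p in picks)
            n_evals += len(picks)
            if d_min > best_d:
                best, best_d = c, d_min
        picks.append(best)
        picked.add(best)
    return picks, n_evals

def euclid(a, b):
    return math.hypot(a[0] - b[0], a[1] - b[1])

rng = random.Random(42)
points = [(rng.random(), rng.random()) for _ in range(500)]  # toy 2D data

for k in (10, 20, 40):
    _, n = maxmin_pick(points, k, euclid)
    print(k, n)  # 10 22215, then 20 92530, then 40 369460
```

Each doubling of k roughly quadruples the number of comparisons, matching the quadratic trend in the timings.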
