All Pairs

Using KGraph for All-Pairs Similarity Search

Here is how to use KGraph to get functionality similar to Google's all-pairs similarity search algorithm, i.e. to find all pairs of objects whose similarity is above a given threshold. KGraph is not the best tool for all-pairs search on sparse text data -- when the threshold is high, Google's code runs orders of magnitude faster -- but it should be competitive on dense, high-dimensional data. In short, only use the code below when you are working with non-sparse data.

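The trick is to quietly record every pair of highly similar objects inside the distance function of the IndexOracle, which KGraph calls while building its index. Because index construction is parallelized, a spinlock is used to synchronize access to the shared global result list. Note that only pairs whose distance KGraph actually evaluates get recorded; pairs it never compares are missed, so the output is approximate.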

// Copyright 2007 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// ---
// A simple all-similar-pairs algorithm for binary vector input.
// ---
// Author: Roberto Bayardo
// Modified by Wei Dong.

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <memory>
#include <mutex>
#include <tuple>
#include <utility>
#include <vector>

#include <kgraph.h>
#include "boost/smart_ptr/detail/spinlock.hpp"
#include "data-source-iterator.h"

/******************* NEW CODE **********************/
typedef boost::detail::spinlock Lock;
typedef std::lock_guard<Lock> LockGuard;
Lock lock;
std::vector<std::tuple<uint32_t,uint32_t,float>> results;
class Oracle: public kgraph::IndexOracle {
    struct Object {
        uint32_t id;
        std::vector<uint32_t> data;
    };
    std::vector<Object> data;
    float threshold;
public:
    Oracle (DataSourceIterator *it, float th): threshold(th) {
        Object obj;
        while (it->Next(&obj.id, &obj.data)) {
            data.push_back(obj);
        }
        std::cerr << size() << " objects loaded." << std::endl;
    }
    virtual unsigned size () const {
        return data.size();
    }
    virtual float operator () (unsigned i, unsigned j) const {
        auto const &v1 = data[i].data;
        auto const &v2 = data[j].data;
        unsigned p1 = 0, p2 = 0;
        unsigned shared = 0;
        for (;;) {
            if (p1 >= v1.size()) break;
            if (p2 >= v2.size()) break;
            if (v1[p1] < v2[p2]) {
                ++p1;
            }
            else if (v1[p1] > v2[p2]) {
                ++p2;
            }
            else {
                ++shared;
                ++p1;
                ++p2;
            }
        }
        float score = float(shared) / sqrt(float(v1.size() * v2.size()));
        if (score >= threshold) {
            unsigned a = data[i].id;
            unsigned b = data[j].id;
            if (a < b) std::swap(a,b);
            LockGuard guard(lock);
            results.push_back(std::make_tuple(a, b, score));
        }
        return -score;
    }

    void dump () {
        std::sort(results.begin(), results.end());
        results.resize(std::unique(results.begin(), results.end()) - results.begin());
        for (auto const &t: results) {
            std::cout << std::get<0>(t) << ',' << std::get<1>(t) << ',' << std::get<2>(t) << std::endl;
        }
    }
};
/******************* END OF NEW CODE ****************/

int main(int argc, char** argv) {
  time_t start_time;
  time(&start_time);

  // Verify input arguments.
  if (argc != 3) {
    std::cerr << "ERROR: Usage is: ./ap <sim_threshold> <dataset_path>\n";
    return 1;
  }
  const double threshold = strtod(argv[1], 0);
  if (threshold <= 0.0 || threshold > 1.0) {
    std::cerr << "ERROR: The first argument should be a similarity "
              << "threshhold with range (0.0-1.0]\n";
    return 2;
  }
  std::cerr << "; User specified similarity threshold: "
            << threshold << std::endl;

  {
    std::auto_ptr<DataSourceIterator> data(DataSourceIterator::Get(argv[2]));
    if (!data.get())
      return 3;
    /******************* NEW CODE **********************/
    Oracle oracle(data.get(), threshold);
    kgraph::KGraph *run = kgraph::KGraph::create();
    kgraph::KGraph::IndexParams params;
    run->build(oracle, params, NULL);
    oracle.dump();
    delete run;
    /******************* END OF NEW CODE ****************/
  }

  time_t end_time;
  time(&end_time);
  std::cerr << "; Total running time: " << (end_time - start_time)
            << " seconds" << std::endl;

  return 0;
}