vector2個の積集合を取る


このエントリーをはてなブックマークに追加

C++ではset_intersection()を使うと積集合をとることができる。ただし、ちょっと使いにくい。後半で紹介する、inserterと組み合わせる方法の方がよさそう。

準備

set_intersection()を使って2つのvectorの積集合を求める前に、それぞれのvectorをあらかじめソートしておく必要がある。

inserterを使わない方法

まずコードを示す。

#include <iostream>
#include <vector>
#include <algorithm>
#include <iterator>

using namespace std;

void printVector( vector<int> &vec ){
    cout << "size = " << vec.size() << endl;
    for(int i = 0; i < vec.size(); ++i){
        cout << vec[i] << ", ";
    }
    cout << endl << endl;
}

int main(){
    int _vec1[] = {2, 3, 5, 7, 11, 13};
    int _vec2[] = {0, 1, 2, 3, 4, 5, 6};
    
    vector<int> vec1(_vec1, &_vec1[sizeof(_vec1)/sizeof(_vec1[0])]);
    vector<int> vec2(_vec2, &_vec2[sizeof(_vec2)/sizeof(_vec2[0])]);
    printVector(vec1); // 2, 3, 5, 7, 11, 13, 
    printVector(vec2); // 0, 1, 2, 3, 4, 5, 6,

    // vec1 and vec2 must be sorted
    sort(vec1.begin(), vec1.end());
    sort(vec2.begin(), vec2.end());

    vector<int> intersection(vec1.size()+vec2.size());

    vector<int>::iterator it
        = set_intersection(vec1.begin(), vec1.end()
                           , vec2.begin(), vec2.end()
                           , intersection.begin());

    cout << "intersection has " << int(it - intersection.begin()) << " elements.\n";
    printVector(intersection); // 2, 3, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    
    return 0;
}

コード中の変数intersectionは、積集合をとった結果を格納するためのvectorである。ここで注意するのは、あらかじめプログラマが結果格納のための領域を確保しておかなければならないところである。仮に、コード中で

vector<int> intersection;

としてしまうと、segmentation faultで実行時に死ぬ。

また、積集合の要素の数を知るには、set_intersection()が返すイテレータと、intersection.begin()が返すイテレータを比べなければならない。

inserterを使う方法(おすすめ)

こちらの方法がおすすめ。こちらの方法では、set_intersection()の第5引数に、inserter()を用いている。こうすると、あらかじめ変数intersectionに領域を確保しておく必要がなくなるらしい。また、もはやset_intersection()の戻り値をチェックする必要もない。

    vector<int> intersection;
    // vector<int>::iterator it = // iterator is not needed anymore
    set_intersection(vec1.begin(), vec1.end()
                     , vec2.begin(), vec2.end()
                     , inserter(intersection, intersection.end()));

上記inserterの部分は、

back_inserter(intersection)

でもよい。

参考