Python 3のsqlite3モジュールでSQLiteの練習

Python 3の標準ライブラリであるsqlite3を使って、SQLite と呼ばれるデータベースを触ってみるメモです。

基本

テーブルを作成

以下のコードでは、都道府県のデータを格納するprefecturesという名前のテーブルを定義します。このテーブルは、name(都道府県名), capital(都道府県庁所在地), population(人口), area(面積)という4つのカラムを持ちます。

import sqlite3

# データベースの保存先
database_store_path = './example.db'

# 最初にデータベースを表す Connection オブジェクトを作る
conn = sqlite3.connect(database_store_path)

# カーソルオブジェクトをつくる
c = conn.cursor()

# executeによりSQLコマンドを実行
# ここでは、'prefectures'という名前のtableを生成するSQLコマンドを実行
# ここでtext, integer, realというのは「カラム型」を意味する
c.execute('''CREATE TABLE prefectures
             (name text, capital text, population integer, area real)''')

データを挿入

次に、prefecturesテーブルにいくつかの都道府県のデータを挿入していきます。以下では、c.execute()を使ってデータを1個ずつ挿入する方法と、c.executemany()を使ってデータ列を一度に挿入する方法を示しています。

人口、面積のデータは英語版Wikipediaを参照しました。

# executeによりSQLコマンドを実行
# ここでは、'stocks'という名前のtableに、1個分のデータを挿入するSQLコマンドを実行
c.execute("INSERT INTO prefectures VALUES ('Kanagawa','Yokohama',9058094,2415.83)")
c.execute("INSERT INTO prefectures VALUES ('Tokyo','Tokyo',13929280,2194.07)")

# '?'というプレースホルダーを使って1個分のデータを挿入することもできる
c.execute("INSERT INTO prefectures VALUES (?,?,?,?)", ('Chiba', 'Chiba', 6278060, 5157.61))

# '?'というプレースホルダーを使って複数のデータを挿入することもできる
prefecture_list = [('Tochigi', 'Utsunomiya', 1943886, 6408.09),
                   ('Ibaraki', 'Mito', 2871199, 6097.19)]
c.executemany("INSERT INTO prefectures VALUES (?,?,?,?)", prefecture_list)

後始末

以下のコードによりDBを更新して保存します。

# 変更を保存 (commit)
conn.commit()

# 後始末
conn.close()

データの読み込み

さきほどexample.dbに保存したテーブルに格納されたデータを読み込んでみます。データの読み込み方は、SQLite入門 によると以下の3種類あります。実際は1番目と2番目を使いそうな気がします。

# 一度保存したDBを開く
conn = sqlite3.connect(database_store_path)
c = conn.cursor()

# 1. カーソルをイテレータ (iterator) として扱う
c.execute('SELECT * FROM prefectures')
for row in c:
    print(row)
print()
 
# 2. fetchallで結果リストを取得する
c.execute('SELECT * FROM prefectures')
for row in c.fetchall():
    print(row)
print()
 
# 3. fetchoneで1件ずつ取得する
c.execute('SELECT * FROM prefectures')
print(c.fetchone())  # 1番目のレコード
print(c.fetchone())  # 2番目のレコード
print(c.fetchone())  # 3番目のレコード
print(c.fetchone())  # 4番目のレコード
print(c.fetchone())  # 5番目のレコード
print(c.fetchone())  # 6番目のレコードは存在しないのでNoneが返る
print()

# 後始末
conn.close()

出力結果は以下のとおりです。

('Kanagawa', 'Yokohama', 9058094, 2415.83)
('Tokyo', 'Tokyo', 13929280, 2194.07)
('Chiba', 'Chiba', 6278060, 5157.61)
('Tochigi', 'Utsunomiya', 1943886, 6408.09)
('Ibaraki', 'Mito', 2871199, 6097.19)

('Kanagawa', 'Yokohama', 9058094, 2415.83)
('Tokyo', 'Tokyo', 13929280, 2194.07)
('Chiba', 'Chiba', 6278060, 5157.61)
('Tochigi', 'Utsunomiya', 1943886, 6408.09)
('Ibaraki', 'Mito', 2871199, 6097.19)

('Kanagawa', 'Yokohama', 9058094, 2415.83)
('Tokyo', 'Tokyo', 13929280, 2194.07)
('Chiba', 'Chiba', 6278060, 5157.61)
('Tochigi', 'Utsunomiya', 1943886, 6408.09)
('Ibaraki', 'Mito', 2871199, 6097.19)
None

応用

特定のカラムだけを取り出す

4つのカラムのうち、面積と人口のカラムのみを取り出すにはSELECTを以下のように使います。

# 一度保存したDBを開く
conn = sqlite3.connect(database_store_path)
c = conn.cursor()

# 特定の列だけを取り出す
c.execute('SELECT area,population FROM prefectures')
for row in c:
    print(row)

# 後始末
conn.close()

結果は以下です。

(2415.83, 9058094)
(2194.07, 13929280)
(5157.61, 6278060)
(6408.09, 1943886)
(6097.19, 2871199)

これ以降、DBを開いてカーソルを取得する部分と後始末の部分は省略します。変数cにはカーソルが入っていると思ってください。

特定の条件にあったデータだけを取り出す

人口が500万以上のデータだけを取り出すにはSELECTWHEREで条件を加えます。

# 人口が500万以上の行だけを取り出す
c.execute('SELECT * FROM prefectures WHERE population > 5000000')
for row in c:
    print(row)

結果は以下です。

('Kanagawa', 'Yokohama', 9058094, 2415.83)
('Tokyo', 'Tokyo', 13929280, 2194.07)
('Chiba', 'Chiba', 6278060, 5157.61)

データを特定の条件で並び替える

例えば人口が多い順にデータを並び替えて取り出すにはSELECTORDER BYで条件を加えます。最後のDESCというのが降順を意味していて、これをASCにするか省略するかすると昇順になります。

# 人口が多い順に取り出す
c.execute('SELECT * FROM prefectures ORDER BY population DESC')
for row in c:
    print(row)

結果は以下です。

('Tokyo', 'Tokyo', 13929280, 2194.07)
('Kanagawa', 'Yokohama', 9058094, 2415.83)
('Chiba', 'Chiba', 6278060, 5157.61)
('Ibaraki', 'Mito', 2871199, 6097.19)
('Tochigi', 'Utsunomiya', 1943886, 6408.09)

カラムを追加する

人口密度を表すカラムを追加してみます。SQL文でデータを追加・更新・削除する方法 (2/2)を参考にしました。

# 人口密度を表すカラムを追加
c.execute('ALTER TABLE prefectures ADD population_density float')
# 人口密度を挿入
c.execute('UPDATE prefectures SET population_density=population / area')

# 結果を表示
c.execute('SELECT * FROM prefectures')
for row in c:
    print(row)

結果は以下です。

('Kanagawa', 'Yokohama', 9058094, 2415.83, 3749.474921662534)
('Tokyo', 'Tokyo', 13929280, 2194.07, 6348.603280661054)
('Chiba', 'Chiba', 6278060, 5157.61, 1217.242094691146)
('Tochigi', 'Utsunomiya', 1943886, 6408.09, 303.3487357387302)
('Ibaraki', 'Mito', 2871199, 6097.19, 470.9052858775928)

素性が不明なデータベースを探索

データベースファイルが与えられたときに中身を探索する方法について記します。ちなみに、db.pyでデータベース探索 にあるように、db.pyというツールを使えばもっとかんたんにできるようですが、ここでは練習としてsqlite3モジュールのみで探索してみます。

まずはデータベースに格納されているテーブル一覧を取得する方法です。sqlite_masterに格納されている情報を使います。

# テーブル一覧を取得
# c.f. https://www.kite.com/python/answers/how-to-list-tables-using-sqlite3-in-python
c.execute("SELECT name FROM sqlite_master WHERE type='table'")
print(c.fetchall())

結果は以下です。

[('prefectures',)]

これでprefecturesという名前のテーブルがあることがわかりました。

次に、prefecturesテーブルのカラム名を取得します。やり方が複数ありそうで自信がありませんが、以下のコードで取得できました。

# テーブルのカラム名を取得
# https://stackoverflow.com/questions/947215/how-to-get-a-list-of-column-names-on-sqlite3-database")
c.execute("SELECT name FROM PRAGMA_TABLE_INFO('prefectures')")
print(c.fetchall())

結果は以下です。

[('name',), ('capital',), ('population',), ('area',), ('population_density',)]

TensorFlow / Kerasにおけるlogitsの意味

TensorFlowやKerasを使うと遭遇する logits という用語、ざっと検索してもすぐに意味が出てこなかったので書いておきます。なお、この用語は複数の意味を持つ単語なので注意願います。この記事ではあくまで TensorFlow / Keras でのlogitsの意味を説明しています。

logitsの定義

scikit-learnとTensorFlowによる実践機械学習 (Amazonアフィリエイトリンク) のp267を引用すると、logitsとは、「ソフトマックス活性化関数に通す前のニューラルネットワークの出力」です。

概念図

画像を3クラスに分類する分類器をニューラルネットで設計したときの典型的な例を以下に示します。ニューラルネットに1枚画像を入力し、  \left[4.0, 2.0, 1.0  \right] ^T という出力を得たとします。これはソフトマックス活性化関数に通す前の値であり、logitsと呼ばれるものです。

f:id:minus9d:20201025192350p:plain

このlogitsを、ソフトマックス活性化関数に与えて得られる出力は、確率を表すベクトルです。ソフトマックス活性化関数の定義に従って計算すると、 \left[0.844, 0.114, 0.042  \right] ^T が得られます。確率なので足すと1.0になっています。0.844の部分の計算はPythonだと以下のようになります。

>>> import math
>>> print(math.exp(4.0) / (math.exp(4.0) + math.exp(2.0) + math.exp(1.0)))
0.8437947344813395

この確率を表すベクトルと、正解ベクトル  \left[1.0, 0.0, 0.0  \right] ^T との間の誤差は、通常、2つの確率分布の距離を計算する関数である交差エントロピー誤差を用いて計算されます。交差エントロピー誤差の定義に従って計算すると、 0.170 になります。この計算はPythonだと以下のようになります。

>>> import math
>>> print((- 1.0 * math.log(0.8437947) - 0.0 * math.log(0.11419519) - 0.0 * mat
h.log(0.04201007)))
0.1698460604208926

tensorflowでの関数

さきほどの図と同じ計算をtensorflowの関数を使って行うと以下のようになります。

import tensorflow as tf


logits = [[4.0, 2.0, 1.0]]
labels = [[1.0, 0.0, 0.0]]

probabilites = tf.nn.softmax(logits)
print(probabilites)  # [[0.8437947  0.11419519 0.04201007]]

cross_entropy_loss = tf.losses.categorical_crossentropy(y_true=labels, y_pred=probabilites)
print(cross_entropy_loss)  # [0.16984598]

softmaxを経由せず、いきなり交差エントロピー誤差を計算する方法もあります。1つ目は、softmax_cross_entropy_with_logits()を使う方法です。

cross_entropy_loss1 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
print(cross_entropy_loss1)  # [0.16984604]

2つ目は、categorical_crossentropy()の引数from_logitsをTrueにする方法です。

cross_entropy_loss2 = tf.losses.categorical_crossentropy(y_true=labels, y_pred=logits, from_logits=True)
print(cross_entropy_loss2)  # [0.16984604]

logitsと確率を間違えて使ってしまうと正しい値を計算できないので気をつけましょう。

Pythonのexceptionをprint()して何も出ない原因

Python 3にて以下のように例外オブジェクトをprint()するコード

try:
    raise ValueError
except Exception as e:
    print(e)

を実行しても、以下のように何も出力されません。




結論からいうと、上記のコードを以下のように修正すると

try:
    raise ValueError('value error has occured')
except Exception as e:
    print(e)

以下の様に出力されるようになります。

value error has occured

これはなぜかを説明していきます。

Pythonのraise文で投げることができるのは、例外クラスか、例外インスタンスのどちらかです( 参考 )。先程の例外クラスを投げた例

    raise ValueError

は、以下

    raise ValueError()

のように、引数無しで生成した例外インスタンスを投げたのと同じになります。

例外インスタンスをつくるときの引数には何を与えればよいでしょうか? 参考

raise Exception('spam', 'eggs')

という例がある通り、引数には任意のオブジェクトを任意の個数与えることができるようです。ただ、普通は、例外が起こった原因を究明するために役立つ情報を文字列で与えることが多いと思います。

ここで設定したオブジェクトは、except節で'as'を使って指定した変数の.argsの中に入っています。また、この変数を文字列として扱うと、.args を参照しなくても引数を直接印字できるようになっています。よって、下の例ではprint(e.args)としてもprint(e)としても同じ結果になります。

try:
    raise ValueError('spam', 'eggs')
except Exception as e:
    print(e.args)  # ('spam', 'eggs') と表示される
    print(e)  # ('spam', 'eggs') と表示される

ここまでの説明により、

try:
    raise ValueError
except Exception as e:
    print(e)

で何も表示されなかった理由がわかりました。

Visual Studio Codeの基本設定

共有設定を使う場合

2020年3月頃からプレビュー機能として提供されている、設定を共有する機能を有効にすることで、複数のマシンでVSCodeの設定を共有できます。参考: 「Visual Studio Code」が設定の同期に対応、Insider版でテスト中 - 窓の杜

  • 左下の歯車をクリック
  • Turn on Setting Sync...をクリック
  • Turn Onをクリック
  • 共有したいものを選んで Sign in & onをクリック
  • Microsoft / GItHub の好きな方でログイン

一から設定する場合

共有設定を使えず一から設定をやり直す必要があるときは、以下を順番に設定。

Settingsを設定

Visual Studio本体の設定はSettingsから行います。File -> Preferences -> Settingsからも開けるけど、Ctrl + , が便利。

拡張機能

まだまだ研究中です。

Snnipetを設定

競技プログラミングで役立ちます(まだほとんど何も登録してませんが)。File -> Preferences -> User Snippetsから好きな言語を選び、例えば以下のようなjsonを作成します。

{
    "is_prime": {
        "prefix": "is_prime",
        "body": [
            "bool is_prime(const ll n){",
            "    if (n <= 1){",
            "         return false;",
            "    }",
            "   for(ll i = 2; i*i <= n; ++i){",
            "       if (n % i == 0){",
            "           return false;",
            "       }",
            "   }",
            "   return true;",
            "}",
        ],
        "description": "judge n is prime or not"
    }
}
エディタで`is_prime`と打ってTabを押せば上記のスニペットを挿入できます。

## 参考文献

* [https://github.com/microsoft/vscode-tips-and-tricks:title]
    * 設定だけでなくショートカットキーの一覧が豊富
* [Visual Studio Code実践ガイド —— 最新コードエディタを使い倒すテクニック (Amazonのアフィリエイトリンク)](https://amzn.to/2JU8M36)

VS CodeでPython 3の競プロコードをデバッグ実行

AtCoderCodeforcesなどの競技プログラミングサイトでは、通常、標準入力から入力を受け取ります。例えばPythonで回答する場合、あなたの回答(例としてsolve.pyとします)は、入力が書かれたテキストファイルinput.txt

$ python solve.py < input.txt

のように与えられたとき、正しく出力することが求められます。

Visual Studio Codeでは<を使った入力や出力はdebugger/runtime specificということになっています (https://code.visualstudio.com/docs/editor/debugging#_redirect-inputoutput-tofrom-the-debug-target) 。私の理解が正しければ、Pythonを実行する処理系が<をうまく扱えない場合、Visual Studio Code<を使う諦めざるを得ないはずです。実際、Windows + Anaconda Pythonという組み合わせで試行錯誤しましたが、うまくいきませんでした。

このように<を使った入力や出力がうまくいかない処理系を使っている場合であっても、Visual Studio CodePythonデバッグができるような方法を考えました。

手順1. Pythonスクリプトに3行追加

Visual Studio Codeは、以下の形式であれば問題なく引数を受け取ってデバッグ可能です。

$ python solve.py input.txt

そこで、以下のどちらでも同じように入力を受け取れるようにPythonスクリプトを変更します。

$ python solve.py < input.txt
$ python solve.py input.txt

そのためには、Pythonスクリプトに以下の行を入れればOKです(参考: how to take a file as input stream in Python - Stack Overflow )。

import sys


if len(sys.argv) == 2:
    sys.stdin = open(sys.argv[1])

例えば ABC146のB問題ですと、回答は以下のようになります。

import sys


if len(sys.argv) == 2:
    sys.stdin = open(sys.argv[1])

N = int(input())
S = input()
T = ""
for ch in S:
    T += chr(ord('A') + (ord(ch) - ord('A') + N) % 26)
print(T)

手順2. Visual Studio Codeデバッグ実行

事前にPython拡張機能を入れて置く必要があります(多分Pythonで検索して最初に出てくるMicrosoftのものでOK)。

  • Pythonコードを右クリックしてVisual Studio Codeで開きます。

  • 左の方にあるデバッグのアイコンをクリックします。Ctrl + Shift + DでもOKです。

f:id:minus9d:20200723181922p:plain

  • "create a launch.json file." をクリックします。

f:id:minus9d:20200723182544p:plain

  • "Python"を選択します。

f:id:minus9d:20200723182626p:plain

  • "Python File"を選択します。

f:id:minus9d:20200723182700p:plain

  • 自動で作成される雛形に、"args": [/path/to/入力ファイルへのパス] を追加します。パスは、あなたのPythonスクリプトのあるディレクトリを基準とした相対パスでもよいようです。以下に例を示します。
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "args": ["test/sample-1.in"]
        }
    ]
}
  • Pythonスクリプトのタブを選択して、行番号の左側をクリックしてBreakpointをはります。

f:id:minus9d:20200723183947p:plain

  • さっき作った設定が選ばれていることを確認して、再生ボタンを押します。

f:id:minus9d:20200723183414p:plain

  • Debugしましょう!

f:id:minus9d:20200723183548p:plain

ios_base::sync_with_stdio(false); cin.tie(0); の意味

競技プログラミングC++を使うときに、入出力を高速化する目的でおまじないのように書かれる

ios_base::sync_with_stdio(false);
cin.tie(0);

の意味、実はよくわかっていなかったので c++ - Significance of ios_base::sync_with_stdio(false); cin.tie(NULL); - Stack Overflow を主に参考として調べてみました。

sync_with_stdio(false);

C++の標準入出力ストリームがCの入出力と同期しないようにします。代償として、 std::cinscanf() を混ぜたり std::coutprintf() を混ぜたりすると破滅するようになります。

std::cinscanf()を混ぜて入力がおかしくなる例を示します。まず以下のようなテキストファイルを用意します。

100
0 1 2 3 (略) 98 99

これを、以下のようにstd::ios::sync_with_stdio(false); した上で、std::cinscanf()で交互に数値を読み込みます。

#include <iostream>
#include <cstdio>
#include <vector>

int main(void)
{
    // C++の標準入出力ストリームがCの入出力と同期しないようにする
    std::ios::sync_with_stdio(false);

    int N;
    std::cin >> N;

    // 標準入力からcinとscanfで交互に数字を読む
    std::vector<int> arr(N);
    for(int n = 0; n < N; ++n) {
        if (n % 2) std::cin >> arr[n];
        else scanf("%d", &arr[n]);
    }

    // 読み取った結果を表示
    for(int n = 0; n < N; ++n) {
        std::cout << arr[n] << " ";
    }
    std::cout << std::endl;

    return 0;
}

出力例は以下です。正しく数値を読み込めていないことがわかります。

0 0 0 1 0 2 0 3 0 4 0 5 0 6 0 7 0 8 0 9 0 10 0 11 0 12 0 13 0 14 0 15 0 16 0 17 0 18 0 19 0 20 0 21 0 22 0 23 0 24 0 25 0 26 0 27 0 28 0 29 0 30 0 31 0 32 0 33 0 34 0 35 0 36 0 37 0 38 0 39 0 40 0 41 0 42 0 43 0 44 0 45 0 46 0 47 0 48 0 49

コードは https://ideone.com/u6TBWZ で実行を試すことができます。

cin.tie(0)

std::cinstd::coutとの結合を解きます。例えば、std::coutで出力を要求する文字が全部出力される前に、std::cinによる入力待ち状態になることがありえるようになります。

#include <iostream>
int main(void)
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);

    int n;
    for(int rep = 0; rep < 3; ++rep) {
        std::cout << "Input number: ";
        std::cin >> n;
        std::cout << n << " is given." << std::endl;
    }

    return 0;
}

実行例は以下です。Input number: が表示される前に数値の入力が要求されてしまいました。

$ ./a.exe
100
Input number: 100 is given.
200
Input number: 200 is given.
300
Input number: 300 is given.

std::cout << "Input number: " << std::flush; とすればこの問題は解決可能です。

Oversized Pancake Choppers の解説

Google Code Jam 2020 Round 1Cの最終問題であるOversized Pancake Choppersの解説です。

この問題はTest Set1から3の3つから構成されます。本番ではTest Set 1のみ解けました。この記事ではTest Set2とTest Set3について解説します。

問題概要

N個のパンケーキが存在する。パンケーキのサイズはA = {A_1, ..., A_N}である。D人の客に、同じサイズのパンケーキを渡さなければいけない。これを達成するために必要な最小のカット数はいくらか。

Test Set 2

制約

  • 1 ≤ N ≤ 300
  • 2 ≤ D ≤ 50

考察

パンケーキは、「等分でカットする」または「まったくカットしない」のどちらかのときに、「得られるスライス数=カット数+1」が成立しカット数を節約できます。例えば、サイズ14のパンケーキからサイズ4のスライスを3個得るには3回カットしなければいけませんが、サイズ12のパンケーキからサイズ4のスライスを3個得るには2回のカットで済みます。

例としてN = 3, D = 4, A = {3, 5, 6}の場合を考えてみます。客に提供するスライスサイズは、サイズ3のパンケーキを1から4に等分した3, 3/2, 3/3, 3/4のどれかか、サイズ5のパンケーキを1から4に等分した5, 5/2, 5/3, 5/4のどれかか、サイズ6のパンケーキを1から4に等分した6, 6/2, 6/3, 6/4のどれかになります。

Test Set 2では、上に上げたスライスの候補すべてについて、以下を調べます。

  • (a) パンケーキ列Aから、サイズsのスライスをD個取れるか
    • 取れないとすると、このスライスサイズは大きすぎるのでダメ
  • (b) パンケーキ列Aから、サイズsのスライスを取るための最小回数

(a)は、パンケーキのサイズA = {A_1, ..., A_N}それぞれをsで割ってあまりを捨てたものの和をとれば求められます。

(b)は、パンケーキ列Aのうち、サイズsでちょうど割り切れるパンケーキから順に、その中でもカット回数が少なくてすむパンケーキから順に貪欲にカットしていけば求められます。例えば、5回カットして6つのスライスが得られるパンケーキより、2回カットして3つのスライスが得られるパンケーキを優先的にカットするほうが、常に得です。

以上で問題が解けました。時間計算量は  O(DN^2) です。

実装

有理数を扱うのは面倒なので、パンケーキのサイズ列 A = {A_1, ..., A_N} を、スライスサイズが整数になるように何倍かしておくと楽です。

私の実装を以下に示します。Python 3だとTLEしてしまったので、PyPy2で通ることを確認しました。GCJはPyPy3がないので不便です…。

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

"""
客に提供するスライスのサイズは、必ず「どれかのパンケーキを等分したもの」になることに着目。
つまり、スライスサイズは、A[i]/1, A[i]/2, ..., A[i]/(D-1) (i = 0..N-1) のどれかになる。
あるスライスサイズについて、全探索。

Python 3だとSet2でTLE。 PyPy2だとSet2までAC
"""

from __future__ import print_function

import sys


# Python 2, 3両対応
if sys.version_info[0] == 2:
    myinput = raw_input
elif sys.version_info[0] == 3:
    myinput = input


def solve():
    N, D = map(int, myinput().split())
    As = list(map(int, myinput().split()))

    # パンケーキは小さい順に並び替えておく
    As.sort()

    ans = 10 ** 10
    for d in range(1, D + 1):
        Bs = [a * d for a in As]

        for a in As:
            # パンケーキaをd-1回カットしてd個に等分割した場合を以下で考える
            # このときの基準サイズは a / d である。
            # 有理数は扱いにくいので、すべてをd倍する。
            # つまり、パンケーキ: B
            #         基準サイズ: a

            # カット可能か調べる
            avail = 0
            for b in Bs:
                avail += b // a
            # カット不能ならcontinue
            if avail < D:
                continue

            # パンケーキBsそれぞれについて、基準サイズaでちょうど割り切れる場合、
            # 何個のパンケーキを取れるかを数える
            just = []
            for b in Bs:
                if b % a == 0:
                    just.append(b // a)

            # お得なパンケーキから順番にとっていく
            # AsやBsはあらかじめ昇順にソートしてあるので、justも昇順にソートされている
            remain_num = D  # あと何個のスライスが必要か
            cut_num = 0  # カット回数
            for j in just:
                if remain_num >= j:
                    remain_num -= j
                    cut_num += j - 1
                else:
                    break

            # 足りない分は残りのパンケーキから取る
            cut_num += remain_num

            ans = min(ans, cut_num)

    print(ans)            


def main():
    T = int(myinput())
    for testcase in range(T):
        print("Case #{}: ".format(testcase+1), end="")
        solve()

main()

Test Set 3

制約

  • 21ケースは 9000 ≤ N ≤ 10000.
  • 残りのケースは 1 ≤ N ≤ 1000.
  • 2 ≤ D ≤ 50

考察

Test Set 2の解法はTLEするのでもうひと工夫必要です。

Test Set 2の解法では、パンケーキのそれぞれについて、A * D種類のスライスサイズで割り切れるかどうかを判定していたのが無駄でした。 サイズaのパンケーキがカット数節約の効果を得るのは、このパンケーキを1, 2, ..., D分割するときだけなので、この場合のみ考えればよいです。

この考察から、以下のような手続きで, keyが「スライスサイズ」、valueが「そのスライスサイズをちょうど取れる数を並べた配列」であるような辞書を作ると良いことがわかります。

    dict = {}
    for a in A[0], A[1], ..., A[N - 1]:
        for d in 1, 2, ..., D:
            # Fraction(a, d) は、a / dを表現する有理数クラス
            dict[Fraction(a, d)].push_back(d)

例としてN = 3, D = 4, A = {3, 5, 6}の場合を考えてみます。辞書は以下のようになります。

keyは約分された有理数で持つ必要があります。各keyについて、サイズsのスライスを取るための最小回数を求める方法は、Test Set 2で書いたとおりです。

残る問題は、パンケーキ列Aから、サイズsのスライスをD個取れるかの確認です。もし上記辞書のkeyのそれぞれについてこの確認をしてしまうと時間計算量はO(DN^2)となり時間が足りません。

そこで、単調性を利用した二分探索を使います。keyの値をソートすると、小さなスライスサイズのときは問題なくスライスD個をとることができて、スライスサイズを大きくしていくと、どこかのタイミングでスライスD個をとれなくなります。その境界、つまりぎりぎりD個のスライスを作れる最大のスライスサイズs_maxを二分探索で探しておきます。

そうすれば、スライスのサイズsがs_max以下であるかどうかをチェックするだけで、スライスをD個取れるかが分かります。

以上で問題が解けました。時間計算量は、二分探索が O(N \log DN)、各keyについてサイズsのスライスを取るための最小回数を求める部分が O(DN) だと思います。

実装

私の実装を以下に示します。これもPython 3だとTLE、PyPy2でACです。

また、Pythonの標準ライブラリで提供されているFractionクラスは遅すぎたので自力実装しています。カスタムクラスを辞書のkeyとして使うために__hash____eq__を実装しました(参考: python - Object of custom type as dictionary key - Stack Overflow )。カスタムクラスをソートするために普通は比較関数を実装すると思いますが、Python 2とPython 3で必要な比較関数が違う(参考: python - Object of custom type as dictionary key - Stack Overflow)のが面倒だったので、分子 / 分母を浮動小数点数でもっておいてそれをkeyにソートするようにしました。大小関係の逆転が起こると嫌だなと思いましたが一応ACしました。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
パンケーキが価値があるのは、ちょうど等分で割り切れるとき。
すべてのパンケーキについて、1等分 .. D等分を試して
(スライスのサイズ, 何個のスライスが取れるか) のタプルを得る

Python 3だとTLE
"""

from __future__ import print_function
from __future__ import division

import array
from bisect import *
from collections import *
import fractions
from fractions import Fraction
import heapq
from itertools import *
import math
import random
import re
import string
import sys

# Python 2, 3両対応
if sys.version_info[0] == 2:
    input = raw_input
    gcd = fractions.gcd
elif sys.version_info[0] == 3:
    gcd = math.gcd


class MyFraction:
    """fractions.Fractionがあまりに遅いので自分で実装"""
    def __init__(self, n, d):
        gcd_value = gcd(n, d)
        self.numerator = n // gcd_value
        self.denominator = d // gcd_value
        self.real = n / d

    # hashのkeyとして使うために__hash__, __eq__が必要

    def __hash__(self):
        hashed_value = hash((self.numerator, self.denominator))
        return hash((self.numerator, self.denominator))

    def __eq__(self, other):
        return (self.numerator, self.denominator) == \
        (other.numerator, other.denominator)

    def __repr__(self):
        return "{} / {} ({})".format(self.numerator, self.denominator, self.real)


def mydiv(n, fraction):
    """floor(n(整数) / fraction)を返す"""
    return (n * fraction.denominator // fraction.numerator)


def solve_set3(N, D, As):
    ans = 10 ** 10
    slice_size_to_slice_num_list = defaultdict(list)
    for a in As:
        for d in range(1, D + 1):

            # fractions.Fraction()は非常に遅い
            # slice_size = Fraction(a, d)

            # そのため、自作クラスを使う
            slice_size = MyFraction(a, d)

            slice_num = d
            slice_size_to_slice_num_list[slice_size].append(slice_num)

    # slice_sizeのうち、D個カットできる境目を二分探索する
    # これにより、後段でO(N)でスライスが可能かをチェックする必要がなくなる
    slice_size_list = list(slice_size_to_slice_num_list.keys())
    slice_size_list.sort(key=lambda obj: obj.real)
    lo = 0
    hi = len(slice_size_list)
    while hi - lo > 1:
        mid = (lo + hi) // 2

        # パンケーキからslice_sizeサイズのスライスを何個作れるか?
        slice_size = slice_size_list[mid]
        cuttable_slice_num = 0
        for a in As:
            cuttable_slice_num += mydiv(a, slice_size)
            if cuttable_slice_num >= D:
                break

        if cuttable_slice_num >= D:
            lo = mid
        else:
            hi = mid

    # カット可能なスライスの範囲のみを探索
    for slice_size in slice_size_list[:(lo + 1)]:
        slice_num_list = slice_size_to_slice_num_list[slice_size]

        # スライスの個数が小さいものから順番に
        slice_num_list.sort()
        cnt = 0
        profit = 0
        for slice_num in slice_num_list:
            cnt += slice_num
            if cnt <= D:
                profit += 1

        ans = min(ans, D - profit)

    return ans         


def solve():
    N, D = map(int, input().split())
    As = list(map(int, input().split()))
    print(solve_set3(N, D, As))


def main():
    T = int(input())
    for testcase in range(T):
        print("Case #{}: ".format(testcase+1), end="")
        solve()


main()