matplotlibでグラフを保存するときのテンプレート


このエントリーをはてなブックマークに追加

サーバにてmatplotlibを使ってグラフを作成し、ファイルに保存することがよくあるのですが、実装するたびにググりまくって非効率なので、私が高確率で使う機能をテンプレ化しました。

さっそくコードを以下に示します。

import matplotlib
import matplotlib.pyplot as plt
import numpy as np


# ウィンドウを開けない環境でもグラフを作成する
# https://stackoverflow.com/questions/39305810/matplotlib-use-required-before-other-imports-clashes-with-pep8-ignore-or-fix
matplotlib.pyplot.switch_backend('Agg')


def plot_test():
    # プロットしたいデータ系列を2つ作成
    x1 = np.linspace(0, 3, 50)
    y1 = np.sin(x1)

    x2 = np.linspace(0, 4, 70)
    y2 = np.cos(x2)

    # 描画先を作成
    # figsize=(8, 6)により、画像保存時のサイズが800x600になる
    # (figsizeを指定しないとデフォルトで640x480)
    fig, ax = plt.subplots(figsize=(8, 6))

    # 折れ線グラフを2系列プロット
    ax.plot(x1, y1, label='sin curve', marker='o', markersize=3)
    ax.plot(x2, y2, label='cos curve', marker='o', markersize=3)

    # グラフタイトル、X軸タイトル、Y軸タイトルを設定
    ax.set_title('sin and cos curves')
    ax.set_xlabel('x axis')
    ax.set_ylabel('y axis')

    # グリッド線を表示
    ax.grid(axis='both')

    # 凡例を表示
    ax.legend()

    # X軸、Y軸の表示範囲を手動で調整(オプション)
    ax.set_xlim(-0.1, 4.1)
    ax.set_ylim(-1.1, 1.1)

    # グラフの周囲の余計な空白を除去
    fig.tight_layout()

    # ファイルに保存
    fig.savefig('plot.png') 

    # 後始末
    # これがないと、グラフを大量に作成したとき
    # "RuntimeWarning: More than 20 figures have been opened."というメッセージが出る
    # https://stackoverflow.com/questions/45933886/python-plt-close-or-clear-figure-does-not-work
    plt.close(fig)


if __name__ == '__main__':
    plot_test()
        

出力されるplot.pngは以下です。

f:id:minus9d:20210302231718p:plain

大体の説明はコード中に書いてしまいました。以下は補足です。

matplotlib.pyplot.switch_backend('Agg')

上の1行は、サーバのようにウィンドウを開けないような環境でmatplotlibを使うときの定型文です(意味はよくわかってません)。以前は

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

という定型文を書いていたのですが、 python - matplotlib.use required before other imports clashes with pep8. Ignore or fix? - Stack Overflow によると、この書き方でよいようです(注:まだ試せていません)。

参考URL