TensorFlow / Kerasにおけるlogitsの意味

TensorFlowやKerasを使うと遭遇する logits という用語、ざっと検索してもすぐに意味が出てこなかったので書いておきます。なお、この用語は複数の意味を持つ単語なので注意願います。この記事ではあくまで TensorFlow / Keras でのlogitsの意味を説明しています。

logitsの定義

scikit-learnとTensorFlowによる実践機械学習 (Amazonアフィリエイトリンク) のp267を引用すると、logitsとは、「ソフトマックス活性化関数に通す前のニューラルネットワークの出力」です。

概念図

画像を3クラスに分類する分類器をニューラルネットで設計したときの典型的な例を以下に示します。ニューラルネットに1枚画像を入力し、  \left[4.0, 2.0, 1.0  \right] ^T という出力を得たとします。これはソフトマックス活性化関数に通す前の値であり、logitsと呼ばれるものです。

f:id:minus9d:20201025192350p:plain

このlogitsを、ソフトマックス活性化関数に与えて得られる出力は、確率を表すベクトルです。ソフトマックス活性化関数の定義に従って計算すると、 \left[0.844, 0.114, 0.042  \right] ^T が得られます。確率なので足すと1.0になっています。0.844の部分の計算はPythonだと以下のようになります。

>>> import math
>>> print(math.exp(4.0) / (math.exp(4.0) + math.exp(2.0) + math.exp(1.0)))
0.8437947344813395

この確率を表すベクトルと、正解ベクトル  \left[1.0, 0.0, 0.0  \right] ^T との間の誤差は、通常、2つの確率分布の距離を計算する関数である交差エントロピー誤差を用いて計算されます。交差エントロピー誤差の定義に従って計算すると、 0.170 になります。この計算はPythonだと以下のようになります。

>>> import math
>>> print((- 1.0 * math.log(0.8437947) - 0.0 * math.log(0.11419519) - 0.0 * mat
h.log(0.04201007)))
0.1698460604208926

tensorflowでの関数

さきほどの図と同じ計算をtensorflowの関数を使って行うと以下のようになります。

import tensorflow as tf


logits = [[4.0, 2.0, 1.0]]
labels = [[1.0, 0.0, 0.0]]

probabilites = tf.nn.softmax(logits)
print(probabilites)  # [[0.8437947  0.11419519 0.04201007]]

cross_entropy_loss = tf.losses.categorical_crossentropy(y_true=labels, y_pred=probabilites)
print(cross_entropy_loss)  # [0.16984598]

softmaxを経由せず、いきなり交差エントロピー誤差を計算する方法もあります。1つ目は、softmax_cross_entropy_with_logits()を使う方法です。

cross_entropy_loss1 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
print(cross_entropy_loss1)  # [0.16984604]

2つ目は、categorical_crossentropy()の引数from_logitsをTrueにする方法です。

cross_entropy_loss2 = tf.losses.categorical_crossentropy(y_true=labels, y_pred=logits, from_logits=True)
print(cross_entropy_loss2)  # [0.16984604]

logitsと確率を間違えて使ってしまうと正しい値を計算できないので気をつけましょう。

Pythonのexceptionをprint()して何も出ない原因

Python 3にて以下のように例外オブジェクトをprint()するコード

try:
    raise ValueError
except Exception as e:
    print(e)

を実行しても、以下のように何も出力されません。




結論からいうと、上記のコードを以下のように修正すると

try:
    raise ValueError('value error has occured')
except Exception as e:
    print(e)

以下の様に出力されるようになります。

value error has occured

これはなぜかを説明していきます。

Pythonのraise文で投げることができるのは、例外クラスか、例外インスタンスのどちらかです( 参考 )。先程の例外クラスを投げた例

    raise ValueError

は、以下

    raise ValueError()

のように、引数無しで生成した例外インスタンスを投げたのと同じになります。

例外インスタンスをつくるときの引数には何を与えればよいでしょうか? 参考

raise Exception('spam', 'eggs')

という例がある通り、引数には任意のオブジェクトを任意の個数与えることができるようです。ただ、普通は、例外が起こった原因を究明するために役立つ情報を文字列で与えることが多いと思います。

ここで設定したオブジェクトは、except節で'as'を使って指定した変数の.argsの中に入っています。また、この変数を文字列として扱うと、.args を参照しなくても引数を直接印字できるようになっています。よって、下の例ではprint(e.args)としてもprint(e)としても同じ結果になります。

try:
    raise ValueError('spam', 'eggs')
except Exception as e:
    print(e.args)  # ('spam', 'eggs') と表示される
    print(e)  # ('spam', 'eggs') と表示される

ここまでの説明により、

try:
    raise ValueError
except Exception as e:
    print(e)

で何も表示されなかった理由がわかりました。

Visual Studio Codeの基本設定

共有設定を使う場合

2020年3月頃からプレビュー機能として提供されている、設定を共有する機能を有効にすることで、複数のマシンでVSCodeの設定を共有できます。参考: 「Visual Studio Code」が設定の同期に対応、Insider版でテスト中 - 窓の杜

  • 左下の歯車をクリック
  • Turn on Setting Sync...をクリック
  • Turn Onをクリック
  • 共有したいものを選んで Sign in & onをクリック
  • Microsoft / GItHub の好きな方でログイン

一から設定する場合

共有設定を使えず一から設定をやり直す必要があるときは、以下を順番に設定。

Settingsを設定

Visual Studio本体の設定はSettingsから行います。File -> Preferences -> Settingsからも開けるけど、Ctrl + , が便利。

拡張機能

まだまだ研究中です。

Snnipetを設定

競技プログラミングで役立ちます(まだほとんど何も登録してませんが)。File -> Preferences -> User Snippetsから好きな言語を選び、例えば以下のようなjsonを作成します。

{
    "is_prime": {
        "prefix": "is_prime",
        "body": [
            "bool is_prime(const ll n){",
            "    if (n <= 1){",
            "         return false;",
            "    }",
            "   for(ll i = 2; i*i <= n; ++i){",
            "       if (n % i == 0){",
            "           return false;",
            "       }",
            "   }",
            "   return true;",
            "}",
        ],
        "description": "judge n is prime or not"
    }
}
エディタで`is_prime`と打ってTabを押せば上記のスニペットを挿入できます。

## 参考URLs

* [https://github.com/microsoft/vscode-tips-and-tricks:title]
    * 設定だけでなくショートカットキーの一覧が豊富

VS CodeでPython 3の競プロコードをデバッグ実行

AtCoderCodeforcesなどの競技プログラミングサイトでは、通常、標準入力から入力を受け取ります。例えばPythonで回答する場合、あなたの回答(例としてsolve.pyとします)は、入力が書かれたテキストファイルinput.txt

$ python solve.py < input.txt

のように与えられたとき、正しく出力することが求められます。

Visual Studio Codeでは<を使った入力や出力はdebugger/runtime specificということになっています (https://code.visualstudio.com/docs/editor/debugging#_redirect-inputoutput-tofrom-the-debug-target) 。私の理解が正しければ、Pythonを実行する処理系が<をうまく扱えない場合、Visual Studio Code<を使う諦めざるを得ないはずです。実際、Windows + Anaconda Pythonという組み合わせで試行錯誤しましたが、うまくいきませんでした。

このように<を使った入力や出力がうまくいかない処理系を使っている場合であっても、Visual Studio CodePythonデバッグができるような方法を考えました。

手順1. Pythonスクリプトに3行追加

Visual Studio Codeは、以下の形式であれば問題なく引数を受け取ってデバッグ可能です。

$ python solve.py input.txt

そこで、以下のどちらでも同じように入力を受け取れるようにPythonスクリプトを変更します。

$ python solve.py < input.txt
$ python solve.py input.txt

そのためには、Pythonスクリプトに以下の行を入れればOKです(参考: how to take a file as input stream in Python - Stack Overflow )。

import sys


if len(sys.argv) == 2:
    sys.stdin = open(sys.argv[1])

例えば ABC146のB問題ですと、回答は以下のようになります。

import sys


if len(sys.argv) == 2:
    sys.stdin = open(sys.argv[1])

N = int(input())
S = input()
T = ""
for ch in S:
    T += chr(ord('A') + (ord(ch) - ord('A') + N) % 26)
print(T)

手順2. Visual Studio Codeデバッグ実行

事前にPython拡張機能を入れて置く必要があります(多分Pythonで検索して最初に出てくるMicrosoftのものでOK)。

  • Pythonコードを右クリックしてVisual Studio Codeで開きます。

  • 左の方にあるデバッグのアイコンをクリックします。Ctrl + Shift + DでもOKです。

f:id:minus9d:20200723181922p:plain

  • "create a launch.json file." をクリックします。

f:id:minus9d:20200723182544p:plain

  • "Python"を選択します。

f:id:minus9d:20200723182626p:plain

  • "Python File"を選択します。

f:id:minus9d:20200723182700p:plain

  • 自動で作成される雛形に、"args": [/path/to/入力ファイルへのパス] を追加します。パスは、あなたのPythonスクリプトのあるディレクトリを基準とした相対パスでもよいようです。以下に例を示します。
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "args": ["test/sample-1.in"]
        }
    ]
}
  • Pythonスクリプトのタブを選択して、行番号の左側をクリックしてBreakpointをはります。

f:id:minus9d:20200723183947p:plain

  • さっき作った設定が選ばれていることを確認して、再生ボタンを押します。

f:id:minus9d:20200723183414p:plain

  • Debugしましょう!

f:id:minus9d:20200723183548p:plain

ios_base::sync_with_stdio(false); cin.tie(0); の意味

競技プログラミングC++を使うときに、入出力を高速化する目的でおまじないのように書かれる

ios_base::sync_with_stdio(false);
cin.tie(0);

の意味、実はよくわかっていなかったので c++ - Significance of ios_base::sync_with_stdio(false); cin.tie(NULL); - Stack Overflow を主に参考として調べてみました。

sync_with_stdio(false);

C++の標準入出力ストリームがCの入出力と同期しないようにします。代償として、 std::cinscanf() を混ぜたり std::coutprintf() を混ぜたりすると破滅するようになります。

std::cinscanf()を混ぜて入力がおかしくなる例を示します。まず以下のようなテキストファイルを用意します。

100
0 1 2 3 (略) 98 99

これを、以下のようにstd::ios::sync_with_stdio(false); した上で、std::cinscanf()で交互に数値を読み込みます。

#include <iostream>
#include <cstdio>
#include <vector>

int main(void)
{
    // C++の標準入出力ストリームがCの入出力と同期しないようにする
    std::ios::sync_with_stdio(false);

    int N;
    std::cin >> N;

    // 標準入力からcinとscanfで交互に数字を読む
    std::vector<int> arr(N);
    for(int n = 0; n < N; ++n) {
        if (n % 2) std::cin >> arr[n];
        else scanf("%d", &arr[n]);
    }

    // 読み取った結果を表示
    for(int n = 0; n < N; ++n) {
        std::cout << arr[n] << " ";
    }
    std::cout << std::endl;

    return 0;
}

出力例は以下です。正しく数値を読み込めていないことがわかります。

0 0 0 1 0 2 0 3 0 4 0 5 0 6 0 7 0 8 0 9 0 10 0 11 0 12 0 13 0 14 0 15 0 16 0 17 0 18 0 19 0 20 0 21 0 22 0 23 0 24 0 25 0 26 0 27 0 28 0 29 0 30 0 31 0 32 0 33 0 34 0 35 0 36 0 37 0 38 0 39 0 40 0 41 0 42 0 43 0 44 0 45 0 46 0 47 0 48 0 49

コードは https://ideone.com/u6TBWZ で実行を試すことができます。

cin.tie(0)

std::cinstd::coutとの結合を解きます。例えば、std::coutで出力を要求する文字が全部出力される前に、std::cinによる入力待ち状態になることがありえるようになります。

#include <iostream>
int main(void)
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);

    int n;
    for(int rep = 0; rep < 3; ++rep) {
        std::cout << "Input number: ";
        std::cin >> n;
        std::cout << n << " is given." << std::endl;
    }

    return 0;
}

実行例は以下です。Input number: が表示される前に数値の入力が要求されてしまいました。

$ ./a.exe
100
Input number: 100 is given.
200
Input number: 200 is given.
300
Input number: 300 is given.

std::cout << "Input number: " << std::flush; とすればこの問題は解決可能です。

Oversized Pancake Choppers の解説

Google Code Jam 2020 Round 1Cの最終問題であるOversized Pancake Choppersの解説です。

この問題はTest Set1から3の3つから構成されます。本番ではTest Set 1のみ解けました。この記事ではTest Set2とTest Set3について解説します。

問題概要

N個のパンケーキが存在する。パンケーキのサイズはA = {A_1, ..., A_N}である。D人の客に、同じサイズのパンケーキを渡さなければいけない。これを達成するために必要な最小のカット数はいくらか。

Test Set 2

制約

  • 1 ≤ N ≤ 300
  • 2 ≤ D ≤ 50

考察

パンケーキは、「等分でカットする」または「まったくカットしない」のどちらかのときに、「得られるスライス数=カット数+1」が成立しカット数を節約できます。例えば、サイズ14のパンケーキからサイズ4のスライスを3個得るには3回カットしなければいけませんが、サイズ12のパンケーキからサイズ4のスライスを3個得るには2回のカットで済みます。

例としてN = 3, D = 4, A = {3, 5, 6}の場合を考えてみます。客に提供するスライスサイズは、サイズ3のパンケーキを1から4に等分した3, 3/2, 3/3, 3/4のどれかか、サイズ5のパンケーキを1から4に等分した5, 5/2, 5/3, 5/4のどれかか、サイズ6のパンケーキを1から4に等分した6, 6/2, 6/3, 6/4のどれかになります。

Test Set 2では、上に上げたスライスの候補すべてについて、以下を調べます。

  • (a) パンケーキ列Aから、サイズsのスライスをD個取れるか
    • 取れないとすると、このスライスサイズは大きすぎるのでダメ
  • (b) パンケーキ列Aから、サイズsのスライスを取るための最小回数

(a)は、パンケーキのサイズA = {A_1, ..., A_N}それぞれをsで割ってあまりを捨てたものの和をとれば求められます。

(b)は、パンケーキ列Aのうち、サイズsでちょうど割り切れるパンケーキから順に、その中でもカット回数が少なくてすむパンケーキから順に貪欲にカットしていけば求められます。例えば、5回カットして6つのスライスが得られるパンケーキより、2回カットして3つのスライスが得られるパンケーキを優先的にカットするほうが、常に得です。

以上で問題が解けました。時間計算量は  O(DN^2) です。

実装

有理数を扱うのは面倒なので、パンケーキのサイズ列 A = {A_1, ..., A_N} を、スライスサイズが整数になるように何倍かしておくと楽です。

私の実装を以下に示します。Python 3だとTLEしてしまったので、PyPy2で通ることを確認しました。GCJはPyPy3がないので不便です…。

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

"""
客に提供するスライスのサイズは、必ず「どれかのパンケーキを等分したもの」になることに着目。
つまり、スライスサイズは、A[i]/1, A[i]/2, ..., A[i]/(D-1) (i = 0..N-1) のどれかになる。
あるスライスサイズについて、全探索。

Python 3だとSet2でTLE。 PyPy2だとSet2までAC
"""

from __future__ import print_function

import sys


# Python 2, 3両対応
if sys.version_info[0] == 2:
    myinput = raw_input
elif sys.version_info[0] == 3:
    myinput = input


def solve():
    N, D = map(int, myinput().split())
    As = list(map(int, myinput().split()))

    # パンケーキは小さい順に並び替えておく
    As.sort()

    ans = 10 ** 10
    for d in range(1, D + 1):
        Bs = [a * d for a in As]

        for a in As:
            # パンケーキaをd-1回カットしてd個に等分割した場合を以下で考える
            # このときの基準サイズは a / d である。
            # 有理数は扱いにくいので、すべてをd倍する。
            # つまり、パンケーキ: B
            #         基準サイズ: a

            # カット可能か調べる
            avail = 0
            for b in Bs:
                avail += b // a
            # カット不能ならcontinue
            if avail < D:
                continue

            # パンケーキBsそれぞれについて、基準サイズaでちょうど割り切れる場合、
            # 何個のパンケーキを取れるかを数える
            just = []
            for b in Bs:
                if b % a == 0:
                    just.append(b // a)

            # お得なパンケーキから順番にとっていく
            # AsやBsはあらかじめ昇順にソートしてあるので、justも昇順にソートされている
            remain_num = D  # あと何個のスライスが必要か
            cut_num = 0  # カット回数
            for j in just:
                if remain_num >= j:
                    remain_num -= j
                    cut_num += j - 1
                else:
                    break

            # 足りない分は残りのパンケーキから取る
            cut_num += remain_num

            ans = min(ans, cut_num)

    print(ans)            


def main():
    T = int(myinput())
    for testcase in range(T):
        print("Case #{}: ".format(testcase+1), end="")
        solve()

main()

Test Set 3

制約

  • 21ケースは 9000 ≤ N ≤ 10000.
  • 残りのケースは 1 ≤ N ≤ 1000.
  • 2 ≤ D ≤ 50

考察

Test Set 2の解法はTLEするのでもうひと工夫必要です。

Test Set 2の解法では、パンケーキのそれぞれについて、A * D種類のスライスサイズで割り切れるかどうかを判定していたのが無駄でした。 サイズaのパンケーキがカット数節約の効果を得るのは、このパンケーキを1, 2, ..., D分割するときだけなので、この場合のみ考えればよいです。

この考察から、以下のような手続きで, keyが「スライスサイズ」、valueが「そのスライスサイズをちょうど取れる数を並べた配列」であるような辞書を作ると良いことがわかります。

    dict = {}
    for a in A[0], A[1], ..., A[N - 1]:
        for d in 1, 2, ..., D:
            # Fraction(a, d) は、a / dを表現する有理数クラス
            dict[Fraction(a, d)].push_back(d)

例としてN = 3, D = 4, A = {3, 5, 6}の場合を考えてみます。辞書は以下のようになります。

keyは約分された有理数で持つ必要があります。各keyについて、サイズsのスライスを取るための最小回数を求める方法は、Test Set 2で書いたとおりです。

残る問題は、パンケーキ列Aから、サイズsのスライスをD個取れるかの確認です。もし上記辞書のkeyのそれぞれについてこの確認をしてしまうと時間計算量はO(DN^2)となり時間が足りません。

そこで、単調性を利用した二分探索を使います。keyの値をソートすると、小さなスライスサイズのときは問題なくスライスD個をとることができて、スライスサイズを大きくしていくと、どこかのタイミングでスライスD個をとれなくなります。その境界、つまりぎりぎりD個のスライスを作れる最大のスライスサイズs_maxを二分探索で探しておきます。

そうすれば、スライスのサイズsがs_max以下であるかどうかをチェックするだけで、スライスをD個取れるかが分かります。

以上で問題が解けました。時間計算量は、二分探索が O(N \log DN)、各keyについてサイズsのスライスを取るための最小回数を求める部分が O(DN) だと思います。

実装

私の実装を以下に示します。これもPython 3だとTLE、PyPy2でACです。

また、Pythonの標準ライブラリで提供されているFractionクラスは遅すぎたので自力実装しています。カスタムクラスを辞書のkeyとして使うために__hash____eq__を実装しました(参考: python - Object of custom type as dictionary key - Stack Overflow )。カスタムクラスをソートするために普通は比較関数を実装すると思いますが、Python 2とPython 3で必要な比較関数が違う(参考: python - Object of custom type as dictionary key - Stack Overflow)のが面倒だったので、分子 / 分母を浮動小数点数でもっておいてそれをkeyにソートするようにしました。大小関係の逆転が起こると嫌だなと思いましたが一応ACしました。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
パンケーキが価値があるのは、ちょうど等分で割り切れるとき。
すべてのパンケーキについて、1等分 .. D等分を試して
(スライスのサイズ, 何個のスライスが取れるか) のタプルを得る

Python 3だとTLE
"""

from __future__ import print_function
from __future__ import division

import array
from bisect import *
from collections import *
import fractions
from fractions import Fraction
import heapq
from itertools import *
import math
import random
import re
import string
import sys

# Python 2, 3両対応
if sys.version_info[0] == 2:
    input = raw_input
    gcd = fractions.gcd
elif sys.version_info[0] == 3:
    gcd = math.gcd


class MyFraction:
    """fractions.Fractionがあまりに遅いので自分で実装"""
    def __init__(self, n, d):
        gcd_value = gcd(n, d)
        self.numerator = n // gcd_value
        self.denominator = d // gcd_value
        self.real = n / d

    # hashのkeyとして使うために__hash__, __eq__が必要

    def __hash__(self):
        hashed_value = hash((self.numerator, self.denominator))
        return hash((self.numerator, self.denominator))

    def __eq__(self, other):
        return (self.numerator, self.denominator) == \
        (other.numerator, other.denominator)

    def __repr__(self):
        return "{} / {} ({})".format(self.numerator, self.denominator, self.real)


def mydiv(n, fraction):
    """floor(n(整数) / fraction)を返す"""
    return (n * fraction.denominator // fraction.numerator)


def solve_set3(N, D, As):
    ans = 10 ** 10
    slice_size_to_slice_num_list = defaultdict(list)
    for a in As:
        for d in range(1, D + 1):

            # fractions.Fraction()は非常に遅い
            # slice_size = Fraction(a, d)

            # そのため、自作クラスを使う
            slice_size = MyFraction(a, d)

            slice_num = d
            slice_size_to_slice_num_list[slice_size].append(slice_num)

    # slice_sizeのうち、D個カットできる境目を二分探索する
    # これにより、後段でO(N)でスライスが可能かをチェックする必要がなくなる
    slice_size_list = list(slice_size_to_slice_num_list.keys())
    slice_size_list.sort(key=lambda obj: obj.real)
    lo = 0
    hi = len(slice_size_list)
    while hi - lo > 1:
        mid = (lo + hi) // 2

        # パンケーキからslice_sizeサイズのスライスを何個作れるか?
        slice_size = slice_size_list[mid]
        cuttable_slice_num = 0
        for a in As:
            cuttable_slice_num += mydiv(a, slice_size)
            if cuttable_slice_num >= D:
                break

        if cuttable_slice_num >= D:
            lo = mid
        else:
            hi = mid

    # カット可能なスライスの範囲のみを探索
    for slice_size in slice_size_list[:(lo + 1)]:
        slice_num_list = slice_size_to_slice_num_list[slice_size]

        # スライスの個数が小さいものから順番に
        slice_num_list.sort()
        cnt = 0
        profit = 0
        for slice_num in slice_num_list:
            cnt += slice_num
            if cnt <= D:
                profit += 1

        ans = min(ans, D - profit)

    return ans         


def solve():
    N, D = map(int, input().split())
    As = list(map(int, input().split()))
    print(solve_set3(N, D, As))


def main():
    T = int(input())
    for testcase in range(T):
        print("Case #{}: ".format(testcase+1), end="")
        solve()


main()

Codeforces Round #641 Orac and Mediansの解説

約1年ぶりにCodeforcesに出場しました。Div. 2の4問目、Orac and Medians が解けそうで解けませんでした。終了後、解説を読みながら考えをまとめました。

問題概要

数列 a_1, a_2, ..., a_Nが与えられる。この数列の任意区間を選び、その区間のすべての値を、その区間の中央値で置換することを好きな回数だけ行う。数列すべての値を値 Kにすることはできるか? できる場合は"yes", できない場合は"no"を出力。

ここで、数列sの中央値は、要素を小さい順に並べたときの ⌊|s|+1/2⌋ 番目の要素と定義。例えば数列{1, 7, 5, 8, 4}の中央値は、3番目に小さい5。数列{1, 7, 5, 8}の中央値は、2番めに小さい5。

考察

Kより小さい値、K、Kより大きい値の3カテゴリのみ考えればよい。以下では、Kより小さい値を0, Kを1, Kを2と置換した数列 b_1, b_2, ..., b_Nを考える。

以下のケースはすぐに分かる。

  • 数列 b_iに1個も1が存在しない場合は明らかに"no"。
  • 数列 b_iの長さが1の場合、 b_1が1の場合のみ"yes"。

今後、数列の長さは2以上、かつ少なくとも1個は1が存在するものとする。

以下のケースも少し考えると分かる。

  • 1と1が隣り合う箇所がある場合は"yes"。
    • "? 1 1" または "1 1 ?" どちらのパターンでも中央値は1なので、"1 1 1" と1が増殖する。これを続けると数列はすべて1になる。
  • 1と2が隣り合う箇所がある場合は"yes"。
    • "1 2" または "2 1" どちらのパターンでも中央値は1なので、"1 1"となる。あとは↑のパターンに帰着。
  • 2と2が隣り合う箇所がある場合は"yes"。
    • "2 2"の前後にある"0"をすべて"2"に変えていくと、いつかは"1"に接触する。
    • すると"1 2"または"2 1"のパターンが発生するので、"yes"になる。

さらに考えると、以下のパターンも成り立つことがわかる。

  • "1 0 1"がある場合は"yes"。
    • "1 0 1" の中央値は1なので、"1 1 1"を作れる。"1 1"のパターンを作れたので"yes"。
  • "1 0 2" または "2 0 1" がある場合は"yes"。
    • "1 0 2" または "2 0 1" どちらのパターンでも中央値は1なので、"1 1 1"を作れる。"1 1"のパターンを作れたので"yes"。
  • "2 0 2" がある場合は"yes"。
    • "2 0 2"の中央値は2なので、"2 2 2"を作れる。
    • "2 2"のパターンを作れたので"yes"。

解説によると"yes"のパターンとなるのは実は上記までで、上に書いたパターンのどれも持たない数列はすべて"no"になる。これを直感的に感じ取るのは私には難しかった。

上に書いたパターンにどれにも当てはまらないということは、数列の1や2の間には、少なくとも2つの0があるということ。例えば以下。

  • "0 1 0 0 1 0 0 2"
  • "2 0 0 2 0 0 1 0 0 2"

0の影響力は強くて、0を1や2に変えるには、1や2で挟む必要がある。しかし上記のようなパターンだと、1や2で挟めない。よって"no"となる。

回答例 (C++)

#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
#define REP(i,n) for(int i = 0; i < (int)(n); ++i)
#define FOR(i,a,b) for(int i = (a); i < (int)(b); ++i)
#define ALL(c) (c).begin(), (c).end()
#define SIZE(v) ((int)v.size())
#define pb push_back
#define mp make_pair
#define mt make_tuple

string solve(int N, int K, vector<int>& As) {
    // Kが1個もなければno
    if (find(ALL(As), K) == As.end()) {
        return "no";
    }

    // 以下では少なくとも1個のKがあるとしてよい

    // 配列の長さ1
    if (N == 1) {
        return "yes";
    }

    // 以下では配列の長さが2以上としてよい

    // 隣り合う数同士にK以上があるならyes
    REP(i, N - 1) {
        if (As[i] >= K && As[i + 1] >= K) {
            return "yes";
        }
    }
    // 1個離れた位置にK以上があるならyes
    REP(i, N - 2) {
        if (As[i] >= K && As[i + 2] >= K) {
            return "yes";
        }
    }

    return "no";
}

int main(void)
{
    cin.sync_with_stdio(false);
    int T; cin >> T;
    REP(t, T) {
        int N; int K;
        cin >> N >> K;
        vector<int> As(N);
        REP(n, N) cin >> As[n];
        cout << solve(N, K, As) << endl;
    }

    return 0;
}