ngng628's Library

This documentation is automatically generated by online-judge-tools/verification-helper

:heavy_check_mark: src/nglib/data_structure/wavelet_matrix.cr

Depends on

Required by

Verified with

Code

require "./succinct_bit_vector.cr"

module NgLib
  # **非負整数列** $A$ に対して、順序に関する様々なクエリに答えます。
  #
  # ジェネリクス T は `#[]` などでの返却値の型を指定するものであって、
  # 数列の値は非負整数でなければならないことに注意してください。
  #
  # 基本的には `CompressedWaveletMatrix` の方が高速です。
  class WaveletMatrix(T)
    include Indexable(T)

    @values : Array(UInt64)
    @n_bits : Int32
    @bit_vectors : Array(NgLib::SuccinctBitVector)

    delegate size, to: @values

    # 非負整数列 $A$ を `values` によって構築します。
    #
    # ```
    # WaveletMatrix.new([1, 3, 2, 5])
    # ```
    def initialize(values : Array(T))
      n = values.size
      @values = values.map(&.to_u64)
      @n_bits = log2_floor({1_u64, @values.max}.max) + 1
      @bit_vectors = Array.new(@n_bits) { NgLib::SuccinctBitVector.new(n) }

      cur = @values.clone
      nxt = Array.new(n) { 0_u64 }
      (@n_bits - 1).downto(0) do |height|
        n.times do |i|
          @bit_vectors[height].set(i) if cur[i].bit(height) == 1
        end

        @bit_vectors[height].build

        indices = [0, @bit_vectors[height].count_zeros]
        n.times do |i|
          bit = @bit_vectors[height][i]
          nxt[indices[bit]] = cur[i]
          indices[bit] += 1
        end

        cur, nxt = nxt, cur
      end
    end

    # 長さ $n$ の非負整数列 $A$ を構築します。
    #
    # ```
    # WaveletMatrix.new(10) { |i| (5 - i) ** 2 }
    # ```
    def self.new(n : Int, & : Int32 -> T)
      WaveletMatrix.new(Array.new(n) { |i| yield i })
    end

    # $A$ の `index` 番目の要素を取得します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.unsafe_fetch(0) # => 1
    # wm.unsafe_fetch(1) # => 3
    # wm.unsafe_fetch(2) # => 2
    # wm.unsafe_fetch(3) # => 5
    # ```
    def unsafe_fetch(index : Int)
      @values.unsafe_fetch(index)
    end

    # `kth` 番目に小さい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(0) # => 1
    # wm.kth_smallest(1) # => 2
    # wm.kth_smallest(2) # => 3
    # wm.kth_smallest(3) # => 5
    # ```
    def kth_smallest(kth : Int32)
      kth_smallest(.., kth)
    end

    # `range` の表す区間に含まれる要素のうち、`kth` 番目に小さい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(1..2, 0) # => 2
    # wm.kth_smallest(1..2, 1) # => 3
    # ```
    def kth_smallest(range : Range(Int?, Int?), kth : Int)
      l = (range.begin || 0)
      r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1)

      ret = T.zero
      (@n_bits - 1).downto(0) do |height|
        lzeros, rzeros = succ0(l, r, height)

        if kth < rzeros - lzeros
          l, r = lzeros, rzeros
        else
          kth -= rzeros - lzeros
          ret |= T.zero.succ << height
          l += @bit_vectors[height].count_zeros - lzeros
          r += @bit_vectors[height].count_zeros - rzeros
        end
      end

      ret
    end

    # `kth` 番目に大きい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(0) # => 5
    # wm.kth_smallest(1) # => 3
    # wm.kth_smallest(2) # => 2
    # wm.kth_smallest(3) # => 1
    # ```
    def kth_largest(kth : Int32)
      kth_largest(.., kth)
    end

    # `range` の表す区間に含まれる要素のうち、`kth` 番目に大きい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_largest(1..2, 0) # => 3
    # wm.kth_largest(1..2, 1) # => 2
    # ```
    def kth_largest(range : Range(Int?, Int?), kth : Int)
      kth_smallest(range, range.size - kth - 1)
    end

    # `item` の個数を返します。
    #
    # 計算量は対数オーダーです。
    def count(item : T)
      count(.., item)
    end

    # `range` の表す区間に含まれる要素の中で `item` の個数を返します。
    def count(range : Range(Int?, Int?), item : T)
      count(range, item..item)
    end

    # `range` の表す区間に含まれる要素の中で `bound` が表す範囲の値の個数を返します。
    def count(range : Range(Int?, Int?), bound : Range(T?, T?)) : Int32
      lower_bound = (bound.begin || 0)
      upper_bound = (bound.end || T::MAX) + (bound.exclusive? || bound.end.nil? ? 0 : 1)
      count_less_eq(range, upper_bound) - count_less_eq(range, lower_bound)
    end

    # `range` の表す区間に含まれる要素のうち、`upper_bound` **未満** の値の最大値を返します。
    #
    # 存在しない場合は `nil` を返します。
    def prev_value(range : Range(Int?, Int?), upper_bound) : T?
      cnt = count_less_eq(range, upper_bound)
      cnt == 0 ? nil : kth_smallest(range, cnt - 1)
    end

    # `range` の表す区間に含まれる要素のうち、`lower_bound` **以上** の値の最小値を返します。
    #
    # 存在しない場合は `nil` を返します。
    def next_value(range : Range(Int?, Int?), lower_bound) : T?
      l = (range.begin || 0)
      r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1)
      cnt = count_less_eq(range, lower_bound)
      cnt == (l...r).size ? nil : kth_smallest(range, cnt)
    end

    # `range` の表す区間に含まれる要素のうち、`upper_bound` **未満** の値の個数を返します。
    #
    # 存在しない場合は `nil` を返します。
    def count_less_eq(range : Range(Int?, Int?), upper_bound) : Int32
      l = (range.begin || 0)
      r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1)

      return (l...r).size.to_i if upper_bound >= (T.zero.succ << @n_bits)

      ret = 0
      (@n_bits - 1).downto(0) do |height|
        lzeros, rzeros = succ0(l, r, height)
        if upper_bound.bit(height) == 1
          ret += rzeros - lzeros
          l += @bit_vectors[height].count_zeros - lzeros
          r += @bit_vectors[height].count_zeros - rzeros
        else
          l, r = lzeros, rzeros
        end
      end

      ret
    end

    @[AlwaysInline]
    private def log2_floor(n : UInt64) : Int32
      log2_floor = 63 - n.leading_zeros_count
      log2_floor + ((n & n - 1) == 0 ? 0 : 1)
    end

    @[AlwaysInline]
    private def succ0(left : Int, right : Int, height : Int)
      lzeros = left <= 0 ? 0 : @bit_vectors[height].count_zeros(0...Math.min(left, size))
      rzeros = right <= 0 ? 0 : @bit_vectors[height].count_zeros(0...Math.min(right, size))
      {lzeros, rzeros}
    end
  end

  class CompressedWaveletMatrix(T)
    include Indexable(T)

    @wm : WaveletMatrix(Int32)
    @uniqued : Array(T)

    delegate size, to: @wm

    def initialize(values : Array(T))
      @uniqued = values.sort.uniq!
      @wm = WaveletMatrix.new(values.map { |elem| get_index(elem) })
    end

    def self.new(n : Int, & : Int32 -> T)
      CompressedWaveletMatrix.new(Array.new(n) { |i| yield i })
    end

    def unsafe_fetch(index : Int)
      @uniqued[@wm.unsafe_fetch(index)]
    end

    # `kth` 番目に小さい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(0) # => 1
    # wm.kth_smallest(1) # => 2
    # wm.kth_smallest(2) # => 3
    # wm.kth_smallest(3) # => 5
    # ```
    def kth_smallest(kth : Int32)
      kth_smallest(.., kth)
    end

    # `range` の表す区間に含まれる要素のうち、`kth` 番目に小さい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(1..2, 0) # => 2
    # wm.kth_smallest(1..2, 1) # => 3
    # ```
    def kth_smallest(range : Range(Int?, Int?), kth : Int)
      @uniqued[@wm.kth_smallest(range, kth)]
    end

    # `kth` 番目に大きい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(0) # => 5
    # wm.kth_smallest(1) # => 3
    # wm.kth_smallest(2) # => 2
    # wm.kth_smallest(3) # => 1
    # ```
    def kth_largest(kth : Int32)
      kth_largest(.., kth)
    end

    # `range` の表す区間に含まれる要素のうち、`kth` 番目に大きい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_largest(1..2, 0) # => 3
    # wm.kth_largest(1..2, 1) # => 2
    # ```
    def kth_largest(range : Range(Int?, Int?), kth : Int)
      kth_smallest(range, range.size - kth - 1)
    end

    # `item` の個数を返します。
    #
    # 計算量は対数オーダーです。
    def count(item : T)
      count(.., item)
    end

    # `range` の表す区間に含まれる要素の中で `item` の個数を返します。
    def count(range : Range(Int?, Int?), item : T)
      count(range, item..item)
    end

    # `range` の表す区間に含まれる要素の中で `bound` が表す範囲の値の個数を返します。
    def count(range : Range(Int?, Int?), bound : Range(T?, T?)) : Int32
      lower_bound = (bound.begin || 0)
      upper_bound = (bound.end || T::MAX) + (bound.exclusive? || bound.end.nil? ? 0 : 1)
      count_less_eq(range, upper_bound) - count_less_eq(range, lower_bound)
    end

    # `range` の表す区間に含まれる要素のうち、`upper_bound` **未満** の値の最大値を返します。
    #
    # 存在しない場合は `nil` を返します。
    def prev_value(range : Range(Int?, Int?), upper_bound) : T?
      cnt = count_less_eq(range, upper_bound)
      cnt == 0 ? nil : kth_smallest(range, cnt - 1)
    end

    # `range` の表す区間に含まれる要素のうち、`lower_bound` **以上** の値の最小値を返します。
    #
    # 存在しない場合は `nil` を返します。
    def next_value(range : Range(Int?, Int?), lower_bound) : T?
      l = (range.begin || 0)
      r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1)
      cnt = count_less_eq(range, lower_bound)
      cnt == (l...r).size ? nil : kth_smallest(range, cnt)
    end

    private def get_index(value : T)
      @uniqued.bsearch_index { |x| x >= value } || @uniqued.size
    end

    private def count_less_eq(range : Range(Int?, Int?), upper_bound) : Int32
      @wm.count_less_eq(range, get_index(upper_bound))
    end
  end
end
# require "./succinct_bit_vector.cr"
module NgLib
  # `SuccinctBitVector` は簡易ビットベクトル(簡易辞書、Succinct Indexable Dictionary)を提供するクラスです。
  #
  # 前計算 $O(n / 32)$ で次の操作が $O(1)$ くらいでできます。
  #
  # - `.[i]` # => $i$ 番目のビットにアクセスする ($O(1)$)
  # - `.sum(r)` # =>  $[0, r)$ にある $1$ の個数を求める ($O(1)$)
  # - `.kth_bit_index(k)` # => $k$ 番目に現れる $1$ の位置を求める ($O(\log{n})$)
  #
  # 例えばこの問題が解けます → [D - Sleep Log](https://atcoder.jp/contests/abc305/tasks/abc305_d)
  class SuccinctBitVector
    getter size : UInt32
    @blocks : UInt32
    @bits : Array(UInt32)
    @sums : Array(UInt32)
    @n_zeros : Int32
    @n_ones : Int32

    # 長さ $n$ のビット列を構築します。
    #
    # 計算量は $O(n / 32)$ です。
    def initialize(n : Int)
      @size = n.to_u32
      @blocks = (@size >> 5) + 1
      @bits = [0_u32] * @blocks
      @sums = [0_u32] * @blocks
      @n_zeros = 0
      @n_ones = 0
    end

    # 長さ $n$ のビット列を構築します。
    #
    # ブロックでは $i$ 番目のビットの値を返してください。
    #
    # 計算量は $O(n)$ です。
    def initialize(n : Int, & : -> Int)
      @size = n.to_u32
      @blocks = (@size >> 5) + 1
      @bits = [0_u32] * @blocks
      @sums = [0_u32] * @blocks

      @size.times do |i|
        set i if (yield i) == 1
      end

      @n_zeros = 0
      @n_ones = 0

      build
    end

    # 左から $i$ 番目のビットを 1 にします。
    #
    # 計算量は $O(1)$ です。
    def set(i : Int)
      @bits[i >> 5] |= 1_u32 << (i & 31)
    end

    # 左から $i$ 番目のビットを 0 にします。
    #
    # 計算量は $O(1)$ です。
    def reset(i : Int)
      @bits[i >> 5] &= ~(1_u32 << (i & 31))
    end

    # 総和が計算できるようにデータ構造を構築します。
    def build
      @sums[0] = 0
      (1...@blocks).each do |i|
        @sums[i] = @sums[i - 1] + @bits[i - 1].popcount
      end
      @n_zeros = @size.to_i - sum
      @n_ones = @size.to_i - @n_zeros
    end

    # $[0, n)$ の総和を返します。
    def sum
      sum(size)
    end

    # $[0, r)$ の総和を返します。
    def sum(r) : UInt32
      @sums[r >> 5] + (@bits[r >> 5] & ((1_u32 << (r & 31)) - 1)).popcount
    end

    # $[l, r)$ の総和を返します。
    def sum(l, r) : UInt32
      sum(r) - sum(l)
    end

    # `range` の範囲で総和を返します。
    def sum(range : Range(Int?, Int?))
      l = (range.begin || 0)
      r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1)
      sum(l, r)
    end

    def count_zeros : Int32
      @n_zeros
    end

    def count_zeros(range : Range(Int?, Int?)) : Int32
      l = (range.begin || 0)
      r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1)
      (l...r).size.to_i - sum(l, r).to_i
    end

    def count_ones : Int32
      @n_ones
    end

    def count_ones(range : Range(Int?, Int?)) : Int32
      l = (range.begin || 0)
      r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1)
      (l...r).size.to_i - count_zeros(range)
    end

    # $i$ 番目のビットを返します。
    def [](i) : UInt32
      ((@bits[i >> 5] >> (i & 31)) & 1) > 0 ? 1_u32 : 0_u32
    end

    # `range` の範囲の総和を返します。
    def [](range : Range(Int?, Int?)) : UInt32
      sum(range)
    end

    # $k$ 番目に出現する $1$ の位置を求めます。
    #
    # 言い換えると、$sum(i) = k$ となるような最小の $i$ を返します。
    #
    # 存在しない場合は `nil` を返します。
    #
    # 本当は $O(1)$ にできるらしいですが、面倒なので $O(\log{n})$ です。
    def kth_bit_index(k) : UInt32
      return 0_u32 if k == 0
      return nil if sum < k

      (0..length).bsearch { |right|
        sum(right) >= k
      } || raise "Not found"
    end
  end
end

module NgLib
  # **非負整数列** $A$ に対して、順序に関する様々なクエリに答えます。
  #
  # ジェネリクス T は `#[]` などでの返却値の型を指定するものであって、
  # 数列の値は非負整数でなければならないことに注意してください。
  #
  # 基本的には `CompressedWaveletMatrix` の方が高速です。
  class WaveletMatrix(T)
    include Indexable(T)

    @values : Array(UInt64)
    @n_bits : Int32
    @bit_vectors : Array(NgLib::SuccinctBitVector)

    delegate size, to: @values

    # 非負整数列 $A$ を `values` によって構築します。
    #
    # ```
    # WaveletMatrix.new([1, 3, 2, 5])
    # ```
    def initialize(values : Array(T))
      n = values.size
      @values = values.map(&.to_u64)
      @n_bits = log2_floor({1_u64, @values.max}.max) + 1
      @bit_vectors = Array.new(@n_bits) { NgLib::SuccinctBitVector.new(n) }

      cur = @values.clone
      nxt = Array.new(n) { 0_u64 }
      (@n_bits - 1).downto(0) do |height|
        n.times do |i|
          @bit_vectors[height].set(i) if cur[i].bit(height) == 1
        end

        @bit_vectors[height].build

        indices = [0, @bit_vectors[height].count_zeros]
        n.times do |i|
          bit = @bit_vectors[height][i]
          nxt[indices[bit]] = cur[i]
          indices[bit] += 1
        end

        cur, nxt = nxt, cur
      end
    end

    # 長さ $n$ の非負整数列 $A$ を構築します。
    #
    # ```
    # WaveletMatrix.new(10) { |i| (5 - i) ** 2 }
    # ```
    def self.new(n : Int, & : Int32 -> T)
      WaveletMatrix.new(Array.new(n) { |i| yield i })
    end

    # $A$ の `index` 番目の要素を取得します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.unsafe_fetch(0) # => 1
    # wm.unsafe_fetch(1) # => 3
    # wm.unsafe_fetch(2) # => 2
    # wm.unsafe_fetch(3) # => 5
    # ```
    def unsafe_fetch(index : Int)
      @values.unsafe_fetch(index)
    end

    # `kth` 番目に小さい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(0) # => 1
    # wm.kth_smallest(1) # => 2
    # wm.kth_smallest(2) # => 3
    # wm.kth_smallest(3) # => 5
    # ```
    def kth_smallest(kth : Int32)
      kth_smallest(.., kth)
    end

    # `range` の表す区間に含まれる要素のうち、`kth` 番目に小さい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(1..2, 0) # => 2
    # wm.kth_smallest(1..2, 1) # => 3
    # ```
    def kth_smallest(range : Range(Int?, Int?), kth : Int)
      l = (range.begin || 0)
      r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1)

      ret = T.zero
      (@n_bits - 1).downto(0) do |height|
        lzeros, rzeros = succ0(l, r, height)

        if kth < rzeros - lzeros
          l, r = lzeros, rzeros
        else
          kth -= rzeros - lzeros
          ret |= T.zero.succ << height
          l += @bit_vectors[height].count_zeros - lzeros
          r += @bit_vectors[height].count_zeros - rzeros
        end
      end

      ret
    end

    # `kth` 番目に大きい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(0) # => 5
    # wm.kth_smallest(1) # => 3
    # wm.kth_smallest(2) # => 2
    # wm.kth_smallest(3) # => 1
    # ```
    def kth_largest(kth : Int32)
      kth_largest(.., kth)
    end

    # `range` の表す区間に含まれる要素のうち、`kth` 番目に大きい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_largest(1..2, 0) # => 3
    # wm.kth_largest(1..2, 1) # => 2
    # ```
    def kth_largest(range : Range(Int?, Int?), kth : Int)
      kth_smallest(range, range.size - kth - 1)
    end

    # `item` の個数を返します。
    #
    # 計算量は対数オーダーです。
    def count(item : T)
      count(.., item)
    end

    # `range` の表す区間に含まれる要素の中で `item` の個数を返します。
    def count(range : Range(Int?, Int?), item : T)
      count(range, item..item)
    end

    # `range` の表す区間に含まれる要素の中で `bound` が表す範囲の値の個数を返します。
    def count(range : Range(Int?, Int?), bound : Range(T?, T?)) : Int32
      lower_bound = (bound.begin || 0)
      upper_bound = (bound.end || T::MAX) + (bound.exclusive? || bound.end.nil? ? 0 : 1)
      count_less_eq(range, upper_bound) - count_less_eq(range, lower_bound)
    end

    # `range` の表す区間に含まれる要素のうち、`upper_bound` **未満** の値の最大値を返します。
    #
    # 存在しない場合は `nil` を返します。
    def prev_value(range : Range(Int?, Int?), upper_bound) : T?
      cnt = count_less_eq(range, upper_bound)
      cnt == 0 ? nil : kth_smallest(range, cnt - 1)
    end

    # `range` の表す区間に含まれる要素のうち、`lower_bound` **以上** の値の最小値を返します。
    #
    # 存在しない場合は `nil` を返します。
    def next_value(range : Range(Int?, Int?), lower_bound) : T?
      l = (range.begin || 0)
      r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1)
      cnt = count_less_eq(range, lower_bound)
      cnt == (l...r).size ? nil : kth_smallest(range, cnt)
    end

    # `range` の表す区間に含まれる要素のうち、`upper_bound` **未満** の値の個数を返します。
    #
    # 存在しない場合は `nil` を返します。
    def count_less_eq(range : Range(Int?, Int?), upper_bound) : Int32
      l = (range.begin || 0)
      r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1)

      return (l...r).size.to_i if upper_bound >= (T.zero.succ << @n_bits)

      ret = 0
      (@n_bits - 1).downto(0) do |height|
        lzeros, rzeros = succ0(l, r, height)
        if upper_bound.bit(height) == 1
          ret += rzeros - lzeros
          l += @bit_vectors[height].count_zeros - lzeros
          r += @bit_vectors[height].count_zeros - rzeros
        else
          l, r = lzeros, rzeros
        end
      end

      ret
    end

    @[AlwaysInline]
    private def log2_floor(n : UInt64) : Int32
      log2_floor = 63 - n.leading_zeros_count
      log2_floor + ((n & n - 1) == 0 ? 0 : 1)
    end

    @[AlwaysInline]
    private def succ0(left : Int, right : Int, height : Int)
      lzeros = left <= 0 ? 0 : @bit_vectors[height].count_zeros(0...Math.min(left, size))
      rzeros = right <= 0 ? 0 : @bit_vectors[height].count_zeros(0...Math.min(right, size))
      {lzeros, rzeros}
    end
  end

  class CompressedWaveletMatrix(T)
    include Indexable(T)

    @wm : WaveletMatrix(Int32)
    @uniqued : Array(T)

    delegate size, to: @wm

    def initialize(values : Array(T))
      @uniqued = values.sort.uniq!
      @wm = WaveletMatrix.new(values.map { |elem| get_index(elem) })
    end

    def self.new(n : Int, & : Int32 -> T)
      CompressedWaveletMatrix.new(Array.new(n) { |i| yield i })
    end

    def unsafe_fetch(index : Int)
      @uniqued[@wm.unsafe_fetch(index)]
    end

    # `kth` 番目に小さい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(0) # => 1
    # wm.kth_smallest(1) # => 2
    # wm.kth_smallest(2) # => 3
    # wm.kth_smallest(3) # => 5
    # ```
    def kth_smallest(kth : Int32)
      kth_smallest(.., kth)
    end

    # `range` の表す区間に含まれる要素のうち、`kth` 番目に小さい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(1..2, 0) # => 2
    # wm.kth_smallest(1..2, 1) # => 3
    # ```
    def kth_smallest(range : Range(Int?, Int?), kth : Int)
      @uniqued[@wm.kth_smallest(range, kth)]
    end

    # `kth` 番目に大きい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_smallest(0) # => 5
    # wm.kth_smallest(1) # => 3
    # wm.kth_smallest(2) # => 2
    # wm.kth_smallest(3) # => 1
    # ```
    def kth_largest(kth : Int32)
      kth_largest(.., kth)
    end

    # `range` の表す区間に含まれる要素のうち、`kth` 番目に大きい値を返します。
    #
    # ```
    # wm = WaveletMatrix.new([1, 3, 2, 5])
    # wm.kth_largest(1..2, 0) # => 3
    # wm.kth_largest(1..2, 1) # => 2
    # ```
    def kth_largest(range : Range(Int?, Int?), kth : Int)
      kth_smallest(range, range.size - kth - 1)
    end

    # `item` の個数を返します。
    #
    # 計算量は対数オーダーです。
    def count(item : T)
      count(.., item)
    end

    # `range` の表す区間に含まれる要素の中で `item` の個数を返します。
    def count(range : Range(Int?, Int?), item : T)
      count(range, item..item)
    end

    # `range` の表す区間に含まれる要素の中で `bound` が表す範囲の値の個数を返します。
    def count(range : Range(Int?, Int?), bound : Range(T?, T?)) : Int32
      lower_bound = (bound.begin || 0)
      upper_bound = (bound.end || T::MAX) + (bound.exclusive? || bound.end.nil? ? 0 : 1)
      count_less_eq(range, upper_bound) - count_less_eq(range, lower_bound)
    end

    # `range` の表す区間に含まれる要素のうち、`upper_bound` **未満** の値の最大値を返します。
    #
    # 存在しない場合は `nil` を返します。
    def prev_value(range : Range(Int?, Int?), upper_bound) : T?
      cnt = count_less_eq(range, upper_bound)
      cnt == 0 ? nil : kth_smallest(range, cnt - 1)
    end

    # `range` の表す区間に含まれる要素のうち、`lower_bound` **以上** の値の最小値を返します。
    #
    # 存在しない場合は `nil` を返します。
    def next_value(range : Range(Int?, Int?), lower_bound) : T?
      l = (range.begin || 0)
      r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1)
      cnt = count_less_eq(range, lower_bound)
      cnt == (l...r).size ? nil : kth_smallest(range, cnt)
    end

    private def get_index(value : T)
      @uniqued.bsearch_index { |x| x >= value } || @uniqued.size
    end

    private def count_less_eq(range : Range(Int?, Int?), upper_bound) : Int32
      @wm.count_less_eq(range, get_index(upper_bound))
    end
  end
end
Back to top page