ngng628's Library

This documentation is automatically generated by online-judge-tools/verification-helper

:warning: src/nglib/data_structure/priority_sum.cr

Depends on

Required by

Code

require "./aatree_multiset"

module NgLib
  # 昇順(降順) $k$ 個の総和を効率良く求めるためのデータ構造です。
  #
  # 値の追加、削除、$k$ の変更ができます。
  class PrioritySum(T)
    getter k : Int32
    getter sum : T

    @tag : Symbol
    @mset : NgLib::AATreeMultiset(T)
    delegate size, to: @mset
    delegate empty?, to: @mset

    # 下位 $k$ 要素の総和を求めるためのデータ構造を構築します。
    def self.min(k : Int, initial : T = T.zero)
      self.new(:min, k, initial)
    end

    # 上位 $k$ 要素の総和を求めるためのデータ構造を構築します。
    def self.max(k : Int, initial : T = T.zero)
      self.new(:max, k, initial)
    end

    def initialize(@tag : Symbol, k : Int, initial : T = T.zero)
      @k = k.to_i32
      @sum = initial
      @mset = NgLib::AATreeMultiset(T).new
    end

    # 要素 $x$ をデータ構造に追加します。
    #
    # 計算量は $O(\log{n})$ です。
    def add(x : T)
      if size < @k
        @sum += x
      else
        kth = @mset.at(kth_index(@k - 1))
        @sum = @sum - kth + x if cmp(x, kth)
      end
      @mset << x
    end

    # Alias for `#add`
    def <<(x : T)
      add(x)
    end

    # 要素 $x$ をデータ構造から削除します。
    #
    # 計算量は $O(\log{n})$ です。
    def delete(x : T)
      if size <= @k
        @sum -= x
        @mset.delete(x)
      else
        kth = @mset.at(kth_index(@k))
        @sum -= x if cmp(x, kth)
        @mset.delete(x)

        kth2 = @mset.at(kth_index(@k - 1))
        @sum += kth2 if cmp(x, kth)
      end
    end

    # $k$ の値を変更します。
    #
    # 計算量は $\Delta k \log{\Delta k}$
    def k=(k : Int)
      if @k < k
        (k - @k).times do |i|
          break if k + i >= size
          @sum += @mset.at(kth_index(k + i))
        end
      elsif @k > k
        (@k - k).times do |i|
          next if @k - i - 1 >= size
          break if @k - i < 1
          @sum -= @mset.at(kth_index(@k - i - 1))
        end
      end
      @k = k.to_i32
    end

    private def kth_index(k : Int)
      case @tag
      when :max
        @mset.size - k - 1
      when :min
        k
      else
        raise IndexError.new
      end
    end

    private def cmp(a : T, b : T)
      case @tag
      when :max
        a > b
      when :min
        a < b
      end
    end
  end
end
# require "./aatree_multiset"
module NgLib
  # 順序付き多重集合です。
  #
  # 平衡二分探索木として [AA木](https://ja.wikipedia.org/wiki/AA%E6%9C%A8) を使用しています。
  # 性能は赤黒木の方が良いことが多い気がします。
  #
  # C++の標準ライブラリの `multiset` と違って、$k$ 番目の値が取り出せることなどが魅力的です。
  class AATreeMultiset(T)
    include Enumerable(T)

    private class Node(T)
      property left : Node(T)?
      property right : Node(T)?
      property parent : Node(T)?
      property key : T
      property level : Int32
      property size : Int32

      def initialize(val : T)
        @left = @right = @parent = nil
        @level = 1
        @key = val
        @size = 1
      end

      def rotate_left : Node(T)
        right = @right.not_nil!
        mid = right.left
        par = @parent
        if right.parent = par
          if par.not_nil!.left == self
            par.not_nil!.left = right
          else
            par.not_nil!.right = right
          end
        end
        mid.parent = self if @right = mid
        right.left = self
        @parent = right

        sz = @size
        @size += (mid ? mid.size : 0) - right.size
        right.size = sz

        right
      end

      def rotate_right : Node(T)
        left = @left.not_nil!
        mid = left.right
        par = @parent

        if left.not_nil!.parent = par
          if par.not_nil!.left == self
            par.not_nil!.left = left
          else
            par.not_nil!.right = left
          end
        end
        mid.parent = self if @left = mid
        left.not_nil!.right = self
        @parent = left

        sz = @size
        @size += (mid ? mid.size : 0) - left.size
        left.size = sz

        left
      end

      def left_side?(node : Node(T)?) : Bool
        @left == node
      end

      def assign(node : Node(T)) : T
        @key = node.key
      end
    end

    @root : Node(T)?

    private def find_node(node : Node(T)?, val : T) : Node(T)?
      return nil unless node
      until val == node.not_nil!.key
        if val < node.not_nil!.key
          break unless node.not_nil!.left
          node = node.not_nil!.left
        else
          break unless node.not_nil!.right
          node = node.not_nil!.right
        end
      end

      while node.not_nil!.left && node.not_nil!.left.not_nil!.key == val
        node = node.not_nil!.left
      end
      while node.not_nil!.right && node.not_nil!.right.not_nil!.key == val
        node = node.not_nil!.right
      end

      node
    end

    private def find_node2(node : Node(T)?, val : T) : Node(T)?
      return nil unless node
      loop do
        if val <= node.not_nil!.key
          break unless node.not_nil!.left
          node = node.not_nil!.left
        else
          break unless node.not_nil!.right
          node = node.not_nil!.right
        end
      end

      node
    end

    private def skew(node : Node(T)?) : Node(T)?
      return nil unless node
      left = node.not_nil!.left
      if left && node.not_nil!.level == left.not_nil!.level
        return node.not_nil!.rotate_right
      end
      node
    end

    private def split(node : Node(T)?) : Node(T)?
      return nil unless node
      right = node.right
      if right && right.not_nil!.right && node.level == right.not_nil!.right.not_nil!.level
        r = node.rotate_left
        r.level += 1
        return r
      end
      node
    end

    private def begin_node : Node(T)?
      return nil unless @root
      node = @root
      while node.not_nil!.left
        node = node.not_nil!.left
      end
      node
    end

    private def next_node(node : Node(T)) : Node(T)?
      if node.right
        node = node.right
        while node.not_nil!.left
          node = node.not_nil!.left
        end
        node
      else
        while node
          par = node.not_nil!.parent
          if par && par.not_nil!.left_side?(node)
            return par
          end
          node = par
        end
        node
      end
    end

    private def level(node : Node(T)?)
      node ? node.level : 0
    end

    def initialize
      @root = nil
      self
    end

    def initialize(enumerable : Enumerable(T))
      @root = nil
      concat(enumerable)
      self
    end

    def concat(elems) : self
      elems.each { |elem| self << elem }
      self
    end

    def includes?(val : T) : Bool
      node = find_node(@root, val)
      node.nil? ? false : node.key == val
    end

    def clear
      @root = nil
    end

    def empty? : Bool
      @root.nil?
    end

    def at(k : Int) : T
      raise IndexError.new unless 0 <= k && k < size
      node = @root
      k += 1
      loop do
        left_size = (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1
        break if left_size == k

        if k < left_size
          node = node.not_nil!.left
        else
          node = node.not_nil!.right
          k -= left_size
        end
      end
      node.not_nil!.key
    end

    def at?(k : Int) : T?
      return nil unless 0 <= k && k < size
      at(k)
    end

    def each(& : T ->)
      node = begin_node
      while node
        yield node.not_nil!.key
        node = next_node(node.not_nil!)
      end
    end

    def add(val : T) : Nil
      add?(val)
      nil
    end

    def add?(val : T) : Bool
      unless @root
        @root = Node.new(val)
        return true
      end

      node = find_node2(@root, val)

      new_node = Node.new(val)
      if val <= node.not_nil!.key
        node.not_nil!.left = new_node
      else
        node.not_nil!.right = new_node
      end
      new_node.not_nil!.parent = node

      node = new_node
      while node
        node = split(skew(node))
        unless node.not_nil!.parent
          @root = node
          break
        end
        node = node.not_nil!.parent
        node.not_nil!.size += 1
      end
      true
    end

    def delete(val : T) : Bool
      return false unless @root

      node = find_node(@root, val)
      return false unless node.not_nil!.key == val

      if node.not_nil!.left || node.not_nil!.right
        child = find_node(node.not_nil!.left ? node.not_nil!.left : node.not_nil!.right, val)
        node.not_nil!.assign(child.not_nil!)
        node = child
      end

      par = node.not_nil!.parent
      if par
        if par.not_nil!.left_side?(node)
          par.left = nil
        else
          par.right = nil
        end
      else
        @root = nil
      end
      node = par

      while node
        new_level = {level(node.left), level(node.right)}.min + 1
        if new_level < node.level
          node.level = new_level
          if new_level < level(node.right)
            node.right.not_nil!.level = new_level
          end
        end

        node.size -= 1
        node = skew(node).not_nil!
        skew(node.right.not_nil!.right) if skew(node.right)

        node = split(node)
        split(node.not_nil!.right)

        unless node.not_nil!.parent
          @root = node
          break
        end
        node = node.not_nil!.parent
      end
      true
    end

    def delete_at(k : Int) : Bool
      delete(at(k))
    end

    def delete_at?(k : Int) : Bool
      val = at?(k)
      if val
        delete(val)
      else
        false
      end
    end

    def lower_bound_index(val : T) : Int32
      node = @root
      return 0 unless node
      index = 0
      while node
        if val <= node.not_nil!.key
          node = node.not_nil!.left
        else
          index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1
          node = node.not_nil!.right
        end
      end
      index
    end

    def upper_bound_index(val : T) : Int32
      node = @root
      return 0 unless node
      index = 0
      while node
        if val < node.not_nil!.key
          node = node.not_nil!.left
        else
          index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1
          node = node.not_nil!.right
        end
      end
      index
    end

    def less_index(val : T) : Int32?
      index = lower_bound_index(val)
      index == 0 ? nil : index - 1
    end

    def less_equal_index(val : T) : Int32?
      index = lower_bound_index(val)
      val == at?(index) ? index : (index == 0 ? nil : index - 1)
    end

    def greater_index(val : T) : Int32?
      index = upper_bound_index(val)
      index == size ? nil : index
    end

    def greater_equal_index(val : T) : Int32?
      index = lower_bound_index(val)
      index == size ? nil : index
    end

    def first : T
      at(0)
    end

    def first? : T?
      at?(0)
    end

    def last : T
      at(size - 1)
    end

    def last? : T?
      at?(size - 1)
    end

    def count(val : T) : Int32
      upper_bound_index(val) - lower_bound_index(val)
    end

    def size : Int32
      @root ? @root.not_nil!.size : 0
    end

    def to_a : Array(T)
      res = Array(T).new
      return res unless @root
      dfs = uninitialized Node(T) -> Nil
      dfs = ->(node : Node(T)) do
        dfs.call(node.left.not_nil!) if node.left
        res << node.key
        dfs.call(node.right.not_nil!) if node.right
        nil
      end
      dfs.call(@root.not_nil!)
      res
    end

    def to_s(io : IO) : Nil
      io << "{" + to_a.join(", ") + "}"
    end

    def inspect(io : IO) : Nil
      to_s(io)
    end

    def ==(other : AATreeSet(T)) : Bool
      self.to_a == other.to_a
    end

    def <<(val : T) : Bool
      add?(val)
    end

    def [](k : Int) : T
      at(k)
    end

    def []?(k : Int) : T | Nil
      at?(k)
    end
  end
end

module NgLib
  # 昇順(降順) $k$ 個の総和を効率良く求めるためのデータ構造です。
  #
  # 値の追加、削除、$k$ の変更ができます。
  class PrioritySum(T)
    getter k : Int32
    getter sum : T

    @tag : Symbol
    @mset : NgLib::AATreeMultiset(T)
    delegate size, to: @mset
    delegate empty?, to: @mset

    # 下位 $k$ 要素の総和を求めるためのデータ構造を構築します。
    def self.min(k : Int, initial : T = T.zero)
      self.new(:min, k, initial)
    end

    # 上位 $k$ 要素の総和を求めるためのデータ構造を構築します。
    def self.max(k : Int, initial : T = T.zero)
      self.new(:max, k, initial)
    end

    def initialize(@tag : Symbol, k : Int, initial : T = T.zero)
      @k = k.to_i32
      @sum = initial
      @mset = NgLib::AATreeMultiset(T).new
    end

    # 要素 $x$ をデータ構造に追加します。
    #
    # 計算量は $O(\log{n})$ です。
    def add(x : T)
      if size < @k
        @sum += x
      else
        kth = @mset.at(kth_index(@k - 1))
        @sum = @sum - kth + x if cmp(x, kth)
      end
      @mset << x
    end

    # Alias for `#add`
    def <<(x : T)
      add(x)
    end

    # 要素 $x$ をデータ構造から削除します。
    #
    # 計算量は $O(\log{n})$ です。
    def delete(x : T)
      if size <= @k
        @sum -= x
        @mset.delete(x)
      else
        kth = @mset.at(kth_index(@k))
        @sum -= x if cmp(x, kth)
        @mset.delete(x)

        kth2 = @mset.at(kth_index(@k - 1))
        @sum += kth2 if cmp(x, kth)
      end
    end

    # $k$ の値を変更します。
    #
    # 計算量は $\Delta k \log{\Delta k}$
    def k=(k : Int)
      if @k < k
        (k - @k).times do |i|
          break if k + i >= size
          @sum += @mset.at(kth_index(k + i))
        end
      elsif @k > k
        (@k - k).times do |i|
          next if @k - i - 1 >= size
          break if @k - i < 1
          @sum -= @mset.at(kth_index(@k - i - 1))
        end
      end
      @k = k.to_i32
    end

    private def kth_index(k : Int)
      case @tag
      when :max
        @mset.size - k - 1
      when :min
        k
      else
        raise IndexError.new
      end
    end

    private def cmp(a : T, b : T)
      case @tag
      when :max
        a > b
      when :min
        a < b
      end
    end
  end
end
Back to top page