ngng628's Library

This documentation is automatically generated by online-judge-tools/verification-helper

:warning: src/nglib/data_structure/aatree_map.cr

Required by

Code

module NgLib
  # 順序付き連想配列です。
  #
  # 平衡二分探索木として [AA木](https://ja.wikipedia.org/wiki/AA%E6%9C%A8) を使用しています。
  # 性能は赤黒木の方が良いことが多い気がします。
  #
  # C++の標準ライブラリの `multiset` と違って、$k$ 番目の値が取り出せることなどが魅力的です。
  class AATreeMap(K, V)
    include Enumerable({K, V})

    private class Node(K, V)
      property left : Node(K, V)?
      property right : Node(K, V)?
      property parent : Node(K, V)?
      property key : K
      property value : V
      property level : Int32
      property size : Int32

      def initialize(item : {K, V})
        @left = @right = @parent = nil
        @level = 1
        @key = item[0]
        @value = item[1]
        @size = 1
      end

      def rotate_left : Node(K, V)
        right = @right.not_nil!
        mid = right.left
        par = @parent
        if right.parent = par
          if par.not_nil!.left == self
            par.not_nil!.left = right
          else
            par.not_nil!.right = right
          end
        end
        mid.parent = self if @right = mid
        right.left = self
        @parent = right

        sz = @size
        @size += (mid ? mid.size : 0) - right.size
        right.size = sz

        right
      end

      def rotate_right : Node(K, V)
        left = @left.not_nil!
        mid = left.right
        par = @parent

        if left.not_nil!.parent = par
          if par.not_nil!.left == self
            par.not_nil!.left = left
          else
            par.not_nil!.right = left
          end
        end
        mid.parent = self if @left = mid
        left.not_nil!.right = self
        @parent = left

        sz = @size
        @size += (mid ? mid.size : 0) - left.size
        left.size = sz

        left
      end

      def left_side?(node : Node(K, V)?) : Bool
        @left == node
      end

      def assign(node : Node(K, V)) : V
        @key = node.key
        @value = node.value
      end
    end

    @root : Node(K, V)?
    @default : V?

    private def find_node(node : Node(K, V)?, key : K) : Node(K, V)?
      return nil unless node
      until key == node.not_nil!.key
        if key < node.not_nil!.key
          break unless node.not_nil!.left
          node = node.not_nil!.left
        else
          break unless node.not_nil!.right
          node = node.not_nil!.right
        end
      end
      node
    end

    private def skew(node : Node(K, V)?) : Node(K, V)?
      return nil unless node
      left = node.not_nil!.left
      if left && node.not_nil!.level == left.not_nil!.level
        return node.not_nil!.rotate_right
      end
      node
    end

    private def split(node : Node(K, V)?) : Node(K, V)?
      return nil unless node
      right = node.right
      if right && right.not_nil!.right && node.level == right.not_nil!.right.not_nil!.level
        r = node.rotate_left
        r.level += 1
        return r
      end
      node
    end

    private def upsert(key : K, value : V) : Nil
      unless @root
        @root = Node.new({key, value})
        return true
      end

      node = find_node(@root, key)
      if node.not_nil!.key == key
        node.not_nil!.value = value
        return
      end

      new_node = Node.new({key, value})
      if key < node.not_nil!.key
        node.not_nil!.left = new_node
      else
        node.not_nil!.right = new_node
      end
      new_node.not_nil!.parent = node

      node = new_node
      while node
        node = split(skew(node))
        unless node.not_nil!.parent
          @root = node
          break
        end
        node = node.not_nil!.parent
        node.not_nil!.size += 1
      end
    end

    private def begin_node : Node(K, V)?
      return nil unless @root
      node = @root
      while node.not_nil!.left
        node = node.not_nil!.left
      end
      node
    end

    private def next_node(node : Node(K, V)) : Node(K, V)?
      if node.right
        node = node.right
        while node.not_nil!.left
          node = node.not_nil!.left
        end
        node
      else
        while node
          par = node.not_nil!.parent
          if par && par.not_nil!.left_side?(node)
            return par
          end
          node = par
        end
        node
      end
    end

    private def level(node : Node(K, V)?)
      node ? node.level : 0
    end

    def initialize
      @root = nil
      @default = nil
      self
    end

    def initialize(@default : V)
      @root = nil
      self
    end

    def initialize(enumerable : Enumerable({K, V}))
      @root = nil
      concat(enumerable)
      self
    end

    def concat(elems) : self
      elems.each { |elem| self << elem }
      self
    end

    def includes?(key : K, value : V) : Bool
      node = find_node(@root, key)
      node.nil? ? false : node.key == key && node.value == value
    end

    def clear
      @root = nil
    end

    def empty? : Bool
      @root.nil?
    end

    def at(k : Int) : {K, V}
      k += size if k < 0
      raise IndexError.new unless 0 <= k && k < size
      node = @root
      k += 1
      loop do
        left_size = (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1
        break if left_size == k

        if k < left_size
          node = node.not_nil!.left
        else
          node = node.not_nil!.right
          k -= left_size
        end
      end
      {node.not_nil!.key, node.not_nil!.value}
    end

    def at?(k : Int) : {K, V}?
      k += size if k < 0
      return nil unless 0 <= k && k < size
      at(k)
    end

    def key_at(k : Int) : K
      at(k)[0]
    end

    def key_at?(k : Int) : K?
      t = at?(k); t ? t[0] : nil
    end

    def value_at(k : Int) : V
      at(k)[1]
    end

    def value_at?(k : Int) : V?
      t = at?(k); t ? t[1] : nil
    end

    def each_key(& : K ->)
      each do |key, _|
        yield key
      end
    end

    def each_value(& : V ->)
      each do |_, value|
        yield value
      end
    end

    def each(& : {K, V} ->)
      node = begin_node
      while node
        pr = {node.not_nil!.key, node.not_nil!.value}
        yield pr
        node = next_node(node.not_nil!)
      end
    end

    def keys : Array(K)
      res = Array(K).new
      each do |key, _|
        res << key
      end
      res
    end

    def values : Array(V)
      res = Array(V).new
      each do |_, value|
        res << value
      end
      res
    end

    def delete_key(key : K) : Bool
      return false unless @root

      node = find_node(@root, key)
      return false unless node.not_nil!.key == key

      if node.not_nil!.left || node.not_nil!.right
        child = find_node(node.not_nil!.left ? node.not_nil!.left : node.not_nil!.right, key)
        node.not_nil!.assign(child.not_nil!)
        node = child
      end

      par = node.not_nil!.parent
      if par
        if par.not_nil!.left_side?(node)
          par.left = nil
        else
          par.right = nil
        end
      else
        @root = nil
      end
      node = par

      while node
        new_level = {level(node.left), level(node.right)}.min + 1
        if new_level < node.level
          node.level = new_level
          if new_level < level(node.right)
            node.right.not_nil!.level = new_level
          end
        end

        node.size -= 1
        node = skew(node).not_nil!
        skew(node.right.not_nil!.right) if skew(node.right)

        node = split(node)
        split(node.not_nil!.right)

        unless node.not_nil!.parent
          @root = node
          break
        end
        node = node.not_nil!.parent
      end
      true
    end

    def delete_at(k : Int)
      key = key_at(k)
      delete_key(key)
    end

    def delete_at(k : Int)
      key = key_at?(k)
      return if key.nil?
      delete_key(key)
    end

    def has_key?(key : K) : Bool
      return false unless @root
      node = find_node(@root, key)
      node.nil? ? false : node.key == key
    end

    def lower_bound_index(key : K) : Int32
      node = @root
      return 0 unless node
      index = 0
      while node
        if key <= node.not_nil!.key
          node = node.not_nil!.left
        else
          index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1
          node = node.not_nil!.right
        end
      end
      index
    end

    def upper_bound_index(key : K) : Int32
      node = @root
      return 0 unless node
      index = 0
      while node
        if key < node.not_nil!.key
          node = node.not_nil!.left
        else
          index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1
          node = node.not_nil!.right
        end
      end
      index
    end

    def less_index(key : K) : Int32?
      index = lower_bound_index(key)
      index == 0 ? nil : index - 1
    end

    def less_equal_index(key : K) : Int32?
      index = lower_bound_index(key)
      key == at?(index) ? index : (index == 0 ? nil : index - 1)
    end

    def greater_index(key : K) : Int32?
      index = upper_bound_index(key)
      index == size ? nil : index
    end

    def greater_equal_index(key : K) : Int32?
      index = lower_bound_index(key)
      index == size ? nil : index
    end

    def size : Int32
      @root ? @root.not_nil!.size : 0
    end

    def to_a : Array({K, V})
      res = Array({K, V}).new
      return res unless @root
      dfs = uninitialized Node(K, V) -> Nil
      dfs = ->(node : Node(K, V)) do
        dfs.call(node.left.not_nil!) if node.left
        res << {node.key, node.value}
        dfs.call(node.right.not_nil!) if node.right
        nil
      end
      dfs.call(@root.not_nil!)
      res
    end

    def to_s(io : IO) : Nil
      io << "{" + to_a.map { |key, value| "#{key} => #{value}" }.join(", ") + "}"
    end

    def inspect(io : IO)
      to_s(io)
    end

    def <<(item : {K, V}) : Nil
      upsert(item[0], item[1])
    end

    def [](key : K) : V
      return @default.not_nil! if @root.nil? && !@default.nil?
      raise KeyError.new "Missing key: #{key.inspect}" unless @root
      node = find_node(@root, key)
      return @default.not_nil! if node.not_nil!.key != key && !@default.nil?
      raise KeyError.new "Missing key: #{key.inspect}" if node.not_nil!.key != key
      node.not_nil!.value
    end

    def []?(key : K) : V?
      return @default if @root.nil?
      node = find_node(@root, key)
      return @default if node.not_nil!.key != key
      node.not_nil!.value
    end

    def []=(key : K, value : V) : V
      upsert(key, value)
      value
    end
  end
end
module NgLib
  # 順序付き連想配列です。
  #
  # 平衡二分探索木として [AA木](https://ja.wikipedia.org/wiki/AA%E6%9C%A8) を使用しています。
  # 性能は赤黒木の方が良いことが多い気がします。
  #
  # C++の標準ライブラリの `multiset` と違って、$k$ 番目の値が取り出せることなどが魅力的です。
  class AATreeMap(K, V)
    include Enumerable({K, V})

    private class Node(K, V)
      property left : Node(K, V)?
      property right : Node(K, V)?
      property parent : Node(K, V)?
      property key : K
      property value : V
      property level : Int32
      property size : Int32

      def initialize(item : {K, V})
        @left = @right = @parent = nil
        @level = 1
        @key = item[0]
        @value = item[1]
        @size = 1
      end

      def rotate_left : Node(K, V)
        right = @right.not_nil!
        mid = right.left
        par = @parent
        if right.parent = par
          if par.not_nil!.left == self
            par.not_nil!.left = right
          else
            par.not_nil!.right = right
          end
        end
        mid.parent = self if @right = mid
        right.left = self
        @parent = right

        sz = @size
        @size += (mid ? mid.size : 0) - right.size
        right.size = sz

        right
      end

      def rotate_right : Node(K, V)
        left = @left.not_nil!
        mid = left.right
        par = @parent

        if left.not_nil!.parent = par
          if par.not_nil!.left == self
            par.not_nil!.left = left
          else
            par.not_nil!.right = left
          end
        end
        mid.parent = self if @left = mid
        left.not_nil!.right = self
        @parent = left

        sz = @size
        @size += (mid ? mid.size : 0) - left.size
        left.size = sz

        left
      end

      def left_side?(node : Node(K, V)?) : Bool
        @left == node
      end

      def assign(node : Node(K, V)) : V
        @key = node.key
        @value = node.value
      end
    end

    @root : Node(K, V)?
    @default : V?

    private def find_node(node : Node(K, V)?, key : K) : Node(K, V)?
      return nil unless node
      until key == node.not_nil!.key
        if key < node.not_nil!.key
          break unless node.not_nil!.left
          node = node.not_nil!.left
        else
          break unless node.not_nil!.right
          node = node.not_nil!.right
        end
      end
      node
    end

    private def skew(node : Node(K, V)?) : Node(K, V)?
      return nil unless node
      left = node.not_nil!.left
      if left && node.not_nil!.level == left.not_nil!.level
        return node.not_nil!.rotate_right
      end
      node
    end

    private def split(node : Node(K, V)?) : Node(K, V)?
      return nil unless node
      right = node.right
      if right && right.not_nil!.right && node.level == right.not_nil!.right.not_nil!.level
        r = node.rotate_left
        r.level += 1
        return r
      end
      node
    end

    private def upsert(key : K, value : V) : Nil
      unless @root
        @root = Node.new({key, value})
        return true
      end

      node = find_node(@root, key)
      if node.not_nil!.key == key
        node.not_nil!.value = value
        return
      end

      new_node = Node.new({key, value})
      if key < node.not_nil!.key
        node.not_nil!.left = new_node
      else
        node.not_nil!.right = new_node
      end
      new_node.not_nil!.parent = node

      node = new_node
      while node
        node = split(skew(node))
        unless node.not_nil!.parent
          @root = node
          break
        end
        node = node.not_nil!.parent
        node.not_nil!.size += 1
      end
    end

    private def begin_node : Node(K, V)?
      return nil unless @root
      node = @root
      while node.not_nil!.left
        node = node.not_nil!.left
      end
      node
    end

    private def next_node(node : Node(K, V)) : Node(K, V)?
      if node.right
        node = node.right
        while node.not_nil!.left
          node = node.not_nil!.left
        end
        node
      else
        while node
          par = node.not_nil!.parent
          if par && par.not_nil!.left_side?(node)
            return par
          end
          node = par
        end
        node
      end
    end

    private def level(node : Node(K, V)?)
      node ? node.level : 0
    end

    def initialize
      @root = nil
      @default = nil
      self
    end

    def initialize(@default : V)
      @root = nil
      self
    end

    def initialize(enumerable : Enumerable({K, V}))
      @root = nil
      concat(enumerable)
      self
    end

    def concat(elems) : self
      elems.each { |elem| self << elem }
      self
    end

    def includes?(key : K, value : V) : Bool
      node = find_node(@root, key)
      node.nil? ? false : node.key == key && node.value == value
    end

    def clear
      @root = nil
    end

    def empty? : Bool
      @root.nil?
    end

    def at(k : Int) : {K, V}
      k += size if k < 0
      raise IndexError.new unless 0 <= k && k < size
      node = @root
      k += 1
      loop do
        left_size = (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1
        break if left_size == k

        if k < left_size
          node = node.not_nil!.left
        else
          node = node.not_nil!.right
          k -= left_size
        end
      end
      {node.not_nil!.key, node.not_nil!.value}
    end

    def at?(k : Int) : {K, V}?
      k += size if k < 0
      return nil unless 0 <= k && k < size
      at(k)
    end

    def key_at(k : Int) : K
      at(k)[0]
    end

    def key_at?(k : Int) : K?
      t = at?(k); t ? t[0] : nil
    end

    def value_at(k : Int) : V
      at(k)[1]
    end

    def value_at?(k : Int) : V?
      t = at?(k); t ? t[1] : nil
    end

    def each_key(& : K ->)
      each do |key, _|
        yield key
      end
    end

    def each_value(& : V ->)
      each do |_, value|
        yield value
      end
    end

    def each(& : {K, V} ->)
      node = begin_node
      while node
        pr = {node.not_nil!.key, node.not_nil!.value}
        yield pr
        node = next_node(node.not_nil!)
      end
    end

    def keys : Array(K)
      res = Array(K).new
      each do |key, _|
        res << key
      end
      res
    end

    def values : Array(V)
      res = Array(V).new
      each do |_, value|
        res << value
      end
      res
    end

    def delete_key(key : K) : Bool
      return false unless @root

      node = find_node(@root, key)
      return false unless node.not_nil!.key == key

      if node.not_nil!.left || node.not_nil!.right
        child = find_node(node.not_nil!.left ? node.not_nil!.left : node.not_nil!.right, key)
        node.not_nil!.assign(child.not_nil!)
        node = child
      end

      par = node.not_nil!.parent
      if par
        if par.not_nil!.left_side?(node)
          par.left = nil
        else
          par.right = nil
        end
      else
        @root = nil
      end
      node = par

      while node
        new_level = {level(node.left), level(node.right)}.min + 1
        if new_level < node.level
          node.level = new_level
          if new_level < level(node.right)
            node.right.not_nil!.level = new_level
          end
        end

        node.size -= 1
        node = skew(node).not_nil!
        skew(node.right.not_nil!.right) if skew(node.right)

        node = split(node)
        split(node.not_nil!.right)

        unless node.not_nil!.parent
          @root = node
          break
        end
        node = node.not_nil!.parent
      end
      true
    end

    def delete_at(k : Int)
      key = key_at(k)
      delete_key(key)
    end

    def delete_at(k : Int)
      key = key_at?(k)
      return if key.nil?
      delete_key(key)
    end

    def has_key?(key : K) : Bool
      return false unless @root
      node = find_node(@root, key)
      node.nil? ? false : node.key == key
    end

    def lower_bound_index(key : K) : Int32
      node = @root
      return 0 unless node
      index = 0
      while node
        if key <= node.not_nil!.key
          node = node.not_nil!.left
        else
          index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1
          node = node.not_nil!.right
        end
      end
      index
    end

    def upper_bound_index(key : K) : Int32
      node = @root
      return 0 unless node
      index = 0
      while node
        if key < node.not_nil!.key
          node = node.not_nil!.left
        else
          index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1
          node = node.not_nil!.right
        end
      end
      index
    end

    def less_index(key : K) : Int32?
      index = lower_bound_index(key)
      index == 0 ? nil : index - 1
    end

    def less_equal_index(key : K) : Int32?
      index = lower_bound_index(key)
      key == at?(index) ? index : (index == 0 ? nil : index - 1)
    end

    def greater_index(key : K) : Int32?
      index = upper_bound_index(key)
      index == size ? nil : index
    end

    def greater_equal_index(key : K) : Int32?
      index = lower_bound_index(key)
      index == size ? nil : index
    end

    def size : Int32
      @root ? @root.not_nil!.size : 0
    end

    def to_a : Array({K, V})
      res = Array({K, V}).new
      return res unless @root
      dfs = uninitialized Node(K, V) -> Nil
      dfs = ->(node : Node(K, V)) do
        dfs.call(node.left.not_nil!) if node.left
        res << {node.key, node.value}
        dfs.call(node.right.not_nil!) if node.right
        nil
      end
      dfs.call(@root.not_nil!)
      res
    end

    def to_s(io : IO) : Nil
      io << "{" + to_a.map { |key, value| "#{key} => #{value}" }.join(", ") + "}"
    end

    def inspect(io : IO)
      to_s(io)
    end

    def <<(item : {K, V}) : Nil
      upsert(item[0], item[1])
    end

    def [](key : K) : V
      return @default.not_nil! if @root.nil? && !@default.nil?
      raise KeyError.new "Missing key: #{key.inspect}" unless @root
      node = find_node(@root, key)
      return @default.not_nil! if node.not_nil!.key != key && !@default.nil?
      raise KeyError.new "Missing key: #{key.inspect}" if node.not_nil!.key != key
      node.not_nil!.value
    end

    def []?(key : K) : V?
      return @default if @root.nil?
      node = find_node(@root, key)
      return @default if node.not_nil!.key != key
      node.not_nil!.value
    end

    def []=(key : K, value : V) : V
      upsert(key, value)
      value
    end
  end
end
Back to top page