require "./nglib/**"
# require "./nglib/**" OO = (1_i64 << 62) - (1_i64 << 31) module NgLib # 順序付き連想配列です。 # # 平衡二分探索木として [AA木](https://ja.wikipedia.org/wiki/AA%E6%9C%A8) を使用しています。 # 性能は赤黒木の方が良いことが多い気がします。 # # C++の標準ライブラリの `multiset` と違って、$k$ 番目の値が取り出せることなどが魅力的です。 class AATreeMap(K, V) include Enumerable({K, V}) private class Node(K, V) property left : Node(K, V)? property right : Node(K, V)? property parent : Node(K, V)? property key : K property value : V property level : Int32 property size : Int32 def initialize(item : {K, V}) @left = @right = @parent = nil @level = 1 @key = item[0] @value = item[1] @size = 1 end def rotate_left : Node(K, V) right = @right.not_nil! mid = right.left par = @parent if right.parent = par if par.not_nil!.left == self par.not_nil!.left = right else par.not_nil!.right = right end end mid.parent = self if @right = mid right.left = self @parent = right sz = @size @size += (mid ? mid.size : 0) - right.size right.size = sz right end def rotate_right : Node(K, V) left = @left.not_nil! mid = left.right par = @parent if left.not_nil!.parent = par if par.not_nil!.left == self par.not_nil!.left = left else par.not_nil!.right = left end end mid.parent = self if @left = mid left.not_nil!.right = self @parent = left sz = @size @size += (mid ? mid.size : 0) - left.size left.size = sz left end def left_side?(node : Node(K, V)?) : Bool @left == node end def assign(node : Node(K, V)) : V @key = node.key @value = node.value end end @root : Node(K, V)? @default : V? private def find_node(node : Node(K, V)?, key : K) : Node(K, V)? return nil unless node until key == node.not_nil!.key if key < node.not_nil!.key break unless node.not_nil!.left node = node.not_nil!.left else break unless node.not_nil!.right node = node.not_nil!.right end end node end private def skew(node : Node(K, V)?) : Node(K, V)? return nil unless node left = node.not_nil!.left if left && node.not_nil!.level == left.not_nil!.level return node.not_nil!.rotate_right end node end private def split(node : Node(K, V)?) : Node(K, V)? return nil unless node right = node.right if right && right.not_nil!.right && node.level == right.not_nil!.right.not_nil!.level r = node.rotate_left r.level += 1 return r end node end private def upsert(key : K, value : V) : Nil unless @root @root = Node.new({key, value}) return true end node = find_node(@root, key) if node.not_nil!.key == key node.not_nil!.value = value return end new_node = Node.new({key, value}) if key < node.not_nil!.key node.not_nil!.left = new_node else node.not_nil!.right = new_node end new_node.not_nil!.parent = node node = new_node while node node = split(skew(node)) unless node.not_nil!.parent @root = node break end node = node.not_nil!.parent node.not_nil!.size += 1 end end private def begin_node : Node(K, V)? return nil unless @root node = @root while node.not_nil!.left node = node.not_nil!.left end node end private def next_node(node : Node(K, V)) : Node(K, V)? if node.right node = node.right while node.not_nil!.left node = node.not_nil!.left end node else while node par = node.not_nil!.parent if par && par.not_nil!.left_side?(node) return par end node = par end node end end private def level(node : Node(K, V)?) node ? node.level : 0 end def initialize @root = nil @default = nil self end def initialize(@default : V) @root = nil self end def initialize(enumerable : Enumerable({K, V})) @root = nil concat(enumerable) self end def concat(elems) : self elems.each { |elem| self << elem } self end def includes?(key : K, value : V) : Bool node = find_node(@root, key) node.nil? ? false : node.key == key && node.value == value end def clear @root = nil end def empty? : Bool @root.nil? end def at(k : Int) : {K, V} k += size if k < 0 raise IndexError.new unless 0 <= k && k < size node = @root k += 1 loop do left_size = (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1 break if left_size == k if k < left_size node = node.not_nil!.left else node = node.not_nil!.right k -= left_size end end {node.not_nil!.key, node.not_nil!.value} end def at?(k : Int) : {K, V}? k += size if k < 0 return nil unless 0 <= k && k < size at(k) end def key_at(k : Int) : K at(k)[0] end def key_at?(k : Int) : K? t = at?(k); t ? t[0] : nil end def value_at(k : Int) : V at(k)[1] end def value_at?(k : Int) : V? t = at?(k); t ? t[1] : nil end def each_key(& : K ->) each do |key, _| yield key end end def each_value(& : V ->) each do |_, value| yield value end end def each(& : {K, V} ->) node = begin_node while node pr = {node.not_nil!.key, node.not_nil!.value} yield pr node = next_node(node.not_nil!) end end def keys : Array(K) res = Array(K).new each do |key, _| res << key end res end def values : Array(V) res = Array(V).new each do |_, value| res << value end res end def delete_key(key : K) : Bool return false unless @root node = find_node(@root, key) return false unless node.not_nil!.key == key if node.not_nil!.left || node.not_nil!.right child = find_node(node.not_nil!.left ? node.not_nil!.left : node.not_nil!.right, key) node.not_nil!.assign(child.not_nil!) node = child end par = node.not_nil!.parent if par if par.not_nil!.left_side?(node) par.left = nil else par.right = nil end else @root = nil end node = par while node new_level = {level(node.left), level(node.right)}.min + 1 if new_level < node.level node.level = new_level if new_level < level(node.right) node.right.not_nil!.level = new_level end end node.size -= 1 node = skew(node).not_nil! skew(node.right.not_nil!.right) if skew(node.right) node = split(node) split(node.not_nil!.right) unless node.not_nil!.parent @root = node break end node = node.not_nil!.parent end true end def delete_at(k : Int) key = key_at(k) delete_key(key) end def delete_at(k : Int) key = key_at?(k) return if key.nil? delete_key(key) end def has_key?(key : K) : Bool return false unless @root node = find_node(@root, key) node.nil? ? false : node.key == key end def lower_bound_index(key : K) : Int32 node = @root return 0 unless node index = 0 while node if key <= node.not_nil!.key node = node.not_nil!.left else index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1 node = node.not_nil!.right end end index end def upper_bound_index(key : K) : Int32 node = @root return 0 unless node index = 0 while node if key < node.not_nil!.key node = node.not_nil!.left else index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1 node = node.not_nil!.right end end index end def less_index(key : K) : Int32? index = lower_bound_index(key) index == 0 ? nil : index - 1 end def less_equal_index(key : K) : Int32? index = lower_bound_index(key) key == at?(index) ? index : (index == 0 ? nil : index - 1) end def greater_index(key : K) : Int32? index = upper_bound_index(key) index == size ? nil : index end def greater_equal_index(key : K) : Int32? index = lower_bound_index(key) index == size ? nil : index end def size : Int32 @root ? @root.not_nil!.size : 0 end def to_a : Array({K, V}) res = Array({K, V}).new return res unless @root dfs = uninitialized Node(K, V) -> Nil dfs = ->(node : Node(K, V)) do dfs.call(node.left.not_nil!) if node.left res << {node.key, node.value} dfs.call(node.right.not_nil!) if node.right nil end dfs.call(@root.not_nil!) res end def to_s(io : IO) : Nil io << "{" + to_a.map { |key, value| "#{key} => #{value}" }.join(", ") + "}" end def inspect(io : IO) to_s(io) end def <<(item : {K, V}) : Nil upsert(item[0], item[1]) end def [](key : K) : V return @default.not_nil! if @root.nil? && !@default.nil? raise KeyError.new "Missing key: #{key.inspect}" unless @root node = find_node(@root, key) return @default.not_nil! if node.not_nil!.key != key && !@default.nil? raise KeyError.new "Missing key: #{key.inspect}" if node.not_nil!.key != key node.not_nil!.value end def []?(key : K) : V? return @default if @root.nil? node = find_node(@root, key) return @default if node.not_nil!.key != key node.not_nil!.value end def []=(key : K, value : V) : V upsert(key, value) value end end end module NgLib # 順序付き多重集合です。 # # 平衡二分探索木として [AA木](https://ja.wikipedia.org/wiki/AA%E6%9C%A8) を使用しています。 # 性能は赤黒木の方が良いことが多い気がします。 # # C++の標準ライブラリの `multiset` と違って、$k$ 番目の値が取り出せることなどが魅力的です。 class AATreeMultiset(T) include Enumerable(T) private class Node(T) property left : Node(T)? property right : Node(T)? property parent : Node(T)? property key : T property level : Int32 property size : Int32 def initialize(val : T) @left = @right = @parent = nil @level = 1 @key = val @size = 1 end def rotate_left : Node(T) right = @right.not_nil! mid = right.left par = @parent if right.parent = par if par.not_nil!.left == self par.not_nil!.left = right else par.not_nil!.right = right end end mid.parent = self if @right = mid right.left = self @parent = right sz = @size @size += (mid ? mid.size : 0) - right.size right.size = sz right end def rotate_right : Node(T) left = @left.not_nil! mid = left.right par = @parent if left.not_nil!.parent = par if par.not_nil!.left == self par.not_nil!.left = left else par.not_nil!.right = left end end mid.parent = self if @left = mid left.not_nil!.right = self @parent = left sz = @size @size += (mid ? mid.size : 0) - left.size left.size = sz left end def left_side?(node : Node(T)?) : Bool @left == node end def assign(node : Node(T)) : T @key = node.key end end @root : Node(T)? private def find_node(node : Node(T)?, val : T) : Node(T)? return nil unless node until val == node.not_nil!.key if val < node.not_nil!.key break unless node.not_nil!.left node = node.not_nil!.left else break unless node.not_nil!.right node = node.not_nil!.right end end while node.not_nil!.left && node.not_nil!.left.not_nil!.key == val node = node.not_nil!.left end while node.not_nil!.right && node.not_nil!.right.not_nil!.key == val node = node.not_nil!.right end node end private def find_node2(node : Node(T)?, val : T) : Node(T)? return nil unless node loop do if val <= node.not_nil!.key break unless node.not_nil!.left node = node.not_nil!.left else break unless node.not_nil!.right node = node.not_nil!.right end end node end private def skew(node : Node(T)?) : Node(T)? return nil unless node left = node.not_nil!.left if left && node.not_nil!.level == left.not_nil!.level return node.not_nil!.rotate_right end node end private def split(node : Node(T)?) : Node(T)? return nil unless node right = node.right if right && right.not_nil!.right && node.level == right.not_nil!.right.not_nil!.level r = node.rotate_left r.level += 1 return r end node end private def begin_node : Node(T)? return nil unless @root node = @root while node.not_nil!.left node = node.not_nil!.left end node end private def next_node(node : Node(T)) : Node(T)? if node.right node = node.right while node.not_nil!.left node = node.not_nil!.left end node else while node par = node.not_nil!.parent if par && par.not_nil!.left_side?(node) return par end node = par end node end end private def level(node : Node(T)?) node ? node.level : 0 end def initialize @root = nil self end def initialize(enumerable : Enumerable(T)) @root = nil concat(enumerable) self end def concat(elems) : self elems.each { |elem| self << elem } self end def includes?(val : T) : Bool node = find_node(@root, val) node.nil? ? false : node.key == val end def clear @root = nil end def empty? : Bool @root.nil? end def at(k : Int) : T raise IndexError.new unless 0 <= k && k < size node = @root k += 1 loop do left_size = (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1 break if left_size == k if k < left_size node = node.not_nil!.left else node = node.not_nil!.right k -= left_size end end node.not_nil!.key end def at?(k : Int) : T? return nil unless 0 <= k && k < size at(k) end def each(& : T ->) node = begin_node while node yield node.not_nil!.key node = next_node(node.not_nil!) end end def add(val : T) : Nil add?(val) nil end def add?(val : T) : Bool unless @root @root = Node.new(val) return true end node = find_node2(@root, val) new_node = Node.new(val) if val <= node.not_nil!.key node.not_nil!.left = new_node else node.not_nil!.right = new_node end new_node.not_nil!.parent = node node = new_node while node node = split(skew(node)) unless node.not_nil!.parent @root = node break end node = node.not_nil!.parent node.not_nil!.size += 1 end true end def delete(val : T) : Bool return false unless @root node = find_node(@root, val) return false unless node.not_nil!.key == val if node.not_nil!.left || node.not_nil!.right child = find_node(node.not_nil!.left ? node.not_nil!.left : node.not_nil!.right, val) node.not_nil!.assign(child.not_nil!) node = child end par = node.not_nil!.parent if par if par.not_nil!.left_side?(node) par.left = nil else par.right = nil end else @root = nil end node = par while node new_level = {level(node.left), level(node.right)}.min + 1 if new_level < node.level node.level = new_level if new_level < level(node.right) node.right.not_nil!.level = new_level end end node.size -= 1 node = skew(node).not_nil! skew(node.right.not_nil!.right) if skew(node.right) node = split(node) split(node.not_nil!.right) unless node.not_nil!.parent @root = node break end node = node.not_nil!.parent end true end def delete_at(k : Int) : Bool delete(at(k)) end def delete_at?(k : Int) : Bool val = at?(k) if val delete(val) else false end end def lower_bound_index(val : T) : Int32 node = @root return 0 unless node index = 0 while node if val <= node.not_nil!.key node = node.not_nil!.left else index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1 node = node.not_nil!.right end end index end def upper_bound_index(val : T) : Int32 node = @root return 0 unless node index = 0 while node if val < node.not_nil!.key node = node.not_nil!.left else index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1 node = node.not_nil!.right end end index end def less_index(val : T) : Int32? index = lower_bound_index(val) index == 0 ? nil : index - 1 end def less_equal_index(val : T) : Int32? index = lower_bound_index(val) val == at?(index) ? index : (index == 0 ? nil : index - 1) end def greater_index(val : T) : Int32? index = upper_bound_index(val) index == size ? nil : index end def greater_equal_index(val : T) : Int32? index = lower_bound_index(val) index == size ? nil : index end def first : T at(0) end def first? : T? at?(0) end def last : T at(size - 1) end def last? : T? at?(size - 1) end def count(val : T) : Int32 upper_bound_index(val) - lower_bound_index(val) end def size : Int32 @root ? @root.not_nil!.size : 0 end def to_a : Array(T) res = Array(T).new return res unless @root dfs = uninitialized Node(T) -> Nil dfs = ->(node : Node(T)) do dfs.call(node.left.not_nil!) if node.left res << node.key dfs.call(node.right.not_nil!) if node.right nil end dfs.call(@root.not_nil!) res end def to_s(io : IO) : Nil io << "{" + to_a.join(", ") + "}" end def inspect(io : IO) : Nil to_s(io) end def ==(other : AATreeSet(T)) : Bool self.to_a == other.to_a end def <<(val : T) : Bool add?(val) end def [](k : Int) : T at(k) end def []?(k : Int) : T | Nil at?(k) end end end module NgLib # 順序付き集合です。 # # 平衡二分探索木として [AA木](https://ja.wikipedia.org/wiki/AA%E6%9C%A8) を使用しています。 # 性能は赤黒木の方が良いことが多い気がします。 # # C++の標準ライブラリの `multiset` と違って、$k$ 番目の値が取り出せることなどが魅力的です。 class AATreeSet(T) include Enumerable(T) private class Node(T) property left : Node(T)? property right : Node(T)? property parent : Node(T)? property key : T property level : Int32 property size : Int32 def initialize(val : T) @left = @right = @parent = nil @level = 1 @key = val @size = 1 end def rotate_left : Node(T) right = @right.not_nil! mid = right.left par = @parent if right.parent = par if par.not_nil!.left == self par.not_nil!.left = right else par.not_nil!.right = right end end mid.parent = self if @right = mid right.left = self @parent = right sz = @size @size += (mid ? mid.size : 0) - right.size right.size = sz right end def rotate_right : Node(T) left = @left.not_nil! mid = left.right par = @parent if left.not_nil!.parent = par if par.not_nil!.left == self par.not_nil!.left = left else par.not_nil!.right = left end end mid.parent = self if @left = mid left.not_nil!.right = self @parent = left sz = @size @size += (mid ? mid.size : 0) - left.size left.size = sz left end def left_side?(node : Node(T)?) : Bool @left == node end def assign(node : Node(T)) : T @key = node.key end end @root : Node(T)? private def find_node(node : Node(T)?, val : T) : Node(T)? return nil unless node until val == node.not_nil!.key if val < node.not_nil!.key break unless node.not_nil!.left node = node.not_nil!.left else break unless node.not_nil!.right node = node.not_nil!.right end end node end private def skew(node : Node(T)?) : Node(T)? return nil unless node left = node.not_nil!.left if left && node.not_nil!.level == left.not_nil!.level return node.not_nil!.rotate_right end node end private def split(node : Node(T)?) : Node(T)? return nil unless node right = node.right if right && right.not_nil!.right && node.level == right.not_nil!.right.not_nil!.level r = node.rotate_left r.level += 1 return r end node end private def begin_node : Node(T)? return nil unless @root node = @root while node.not_nil!.left node = node.not_nil!.left end node end private def next_node(node : Node(T)) : Node(T)? if node.right node = node.right while node.not_nil!.left node = node.not_nil!.left end node else while node par = node.not_nil!.parent if par && par.not_nil!.left_side?(node) return par end node = par end node end end private def level(node : Node(T)?) node ? node.level : 0 end def initialize @root = nil self end def initialize(enumerable : Enumerable(T)) @root = nil concat(enumerable) self end def concat(elems) : self elems.each { |elem| self << elem } self end def includes?(val : T) : Bool node = find_node(@root, val) node.nil? ? false : node.key == val end def clear @root = nil end def empty? : Bool @root.nil? end def at(k : Int) : T raise IndexError.new unless 0 <= k && k < size node = @root k += 1 loop do left_size = (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1 break if left_size == k if k < left_size node = node.not_nil!.left else node = node.not_nil!.right k -= left_size end end node.not_nil!.key end def at?(k : Int) : T? return nil unless 0 <= k && k < size at(k) end def each(& : T ->) node = begin_node while node yield node.not_nil!.key node = next_node(node.not_nil!) end end def add(val : T) : Nil add?(val) nil end def add?(val : T) : Bool unless @root @root = Node.new(val) return true end node = find_node(@root, val) return false if node.not_nil!.key == val # NOT multi new_node = Node.new(val) if val <= node.not_nil!.key node.not_nil!.left = new_node else node.not_nil!.right = new_node end new_node.not_nil!.parent = node node = new_node while node node = split(skew(node)) unless node.not_nil!.parent @root = node break end node = node.not_nil!.parent node.not_nil!.size += 1 end true end def delete(val : T) : Bool return false unless @root node = find_node(@root, val) return false unless node.not_nil!.key == val if node.not_nil!.left || node.not_nil!.right child = find_node(node.not_nil!.left ? node.not_nil!.left : node.not_nil!.right, val) node.not_nil!.assign(child.not_nil!) node = child end par = node.not_nil!.parent if par if par.not_nil!.left_side?(node) par.left = nil else par.right = nil end else @root = nil end node = par while node new_level = {level(node.left), level(node.right)}.min + 1 if new_level < node.level node.level = new_level if new_level < level(node.right) node.right.not_nil!.level = new_level end end node.size -= 1 node = skew(node).not_nil! skew(node.right.not_nil!.right) if skew(node.right) node = split(node) split(node.not_nil!.right) unless node.not_nil!.parent @root = node break end node = node.not_nil!.parent end true end def delete_at(k : Int) : Bool delete(at(k)) end def delete_at?(k : Int) : Bool val = at?(k) if val delete(val) else false end end def lower_bound_index(val : T) : Int32 node = @root return 0 unless node index = 0 while node if val <= node.not_nil!.key node = node.not_nil!.left else index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1 node = node.not_nil!.right end end index end def upper_bound_index(val : T) : Int32 node = @root return 0 unless node index = 0 while node if val < node.not_nil!.key node = node.not_nil!.left else index += (node.not_nil!.left ? node.not_nil!.left.not_nil!.size : 0) + 1 node = node.not_nil!.right end end index end def less_index(val : T) : Int32? index = lower_bound_index(val) index == 0 ? nil : index - 1 end def less_equal_index(val : T) : Int32? index = lower_bound_index(val) val == at?(index) ? index : (index == 0 ? nil : index - 1) end def greater_index(val : T) : Int32? index = upper_bound_index(val) index == size ? nil : index end def greater_equal_index(val : T) : Int32? index = lower_bound_index(val) index == size ? nil : index end def first : T at(0) end def first? : T? at?(0) end def last : T at(size - 1) end def last? : T? at?(size - 1) end def count(val : T) : Int32 upper_bound_index(val) - lower_bound_index(val) end def size : Int32 @root.try &.size || 0 end def to_a : Array(T) res = Array(T).new return res unless @root dfs = uninitialized Node(T) -> Nil dfs = ->(node : Node(T)) do left = node.left right = node.right dfs.call(left) if left res << node.key dfs.call(right) if right nil end dfs.call(@root.not_nil!) res end def to_s(io : IO) : Nil io << "{" + to_a.join(", ") + "}" end def inspect(io : IO) : Nil to_s(io) end def ==(other : AATreeSet(T)) : Bool self.to_a == other.to_a end def <<(val : T) : Bool add?(val) end def [](k : Int) : T at(k) end def []?(k : Int) : T | Nil at?(k) end end end module NgLib # 区間作用・$1$ 点取得ができるセグメント木です。 # # 作用素 $f$ は、要素 $x$ と同じ型である必要があります。 class DualSegTree(T) include Indexable(T) include Indexable::Mutable(T) class NotInitializeError < Exception; end getter size : Int32 @n_leaves : Int32 @nodes : Array(T?) # 作用素 $f(x) = x + f$ とした双対セグメント木を作ります。 def self.range_add(values : Array(T)) self.new(values) { |applicator, value| value + applicator } end # :ditto: def self.range_add(size : Int) self.new(size) { |applicator, value| value + applicator } end # 作用素 $f(x) = f$ とした双対セグメント木を作ります。 def self.range_update(values : Array(T)) self.new(values) { |applicator, _value| applicator } end # :ditto: def self.range_update(size : Int) self.new(size) { |applicator, _value| applicator } end # 作用素 $f(x) = \min(f, x)$ とした双対セグメント木を作ります。 def self.range_chmin(values : Array(T)) self.new(values) { |applicator, value| {applicator, value}.min } end # :ditto: def self.range_chmin(size : Int) self.new(size) { |applicator, value| {applicator, value}.min } end # 作用素 $f(x) = \max(f, x)$ とした双対セグメント木を作ります。 def self.range_chmax(values : Array(T)) self.new(values) { |applicator, value| {applicator, value}.max } end # :ditto: def self.range_chmax(size : Int) self.new(size) { |applicator, value| {applicator, value}.max } end # 作用素を $f$ とした、要素数 $n$ の双対セグメント木を作ります。 # # 各要素は単位元を表す `nil` で初期化されます。 # # ``` # seg = NgLib::DualSegTree(Int32).new(5) { |f, x| x + f } # seg # => [nil, nil, nil, nil, nil] # ``` def initialize(size : Int, &@application : T, T -> T) @size = size.to_i32 @n_leaves = 1 while @n_leaves < @size @n_leaves *= 2 end @nodes = Array(T?).new(@n_leaves * 2) { nil } end # 作用素を $f$ とした、$i$ 番目の要素が `values[i]` の双対セグメント木を作ります。 # # ``` # seg = NgLib::DualSegTree(Int32).new([*(1..5)]) { |f, x| x + f } # seg # => [1, 2, 3, 4, 5] # ``` def initialize(values : Array(T), &@application : T, T -> T) @size = values.size @n_leaves = 1 while @n_leaves < @size @n_leaves *= 2 end @nodes = Array(T?).new(@n_leaves * 2) { nil } values.each_with_index do |elem, i| @nodes[i + @n_leaves] = elem end end def unsafe_fetch(index : Int) node_index = index + @n_leaves push(node_index) @nodes[node_index] || raise NotInitializeError.new end def unsafe_put(index : Int, value : T) node_index = index + @n_leaves push(node_index) @nodes[node_index] = value end # `#apply` へのエイリアスです。 # # ``` # seg = NgLib::DualSegTree(Int32).new([*(1..5)]) { |f, x| x + f } # seg # => [1, 2, 3, 4, 5] # seg[0...2] = 10 # seg # => [11, 12, 3, 4, 5] # ``` def []=(range : Range(Int?, Int?), applicator : T) apply(range, applicator) end # `#apply` へのエイリアスです。 # # ``` # seg = NgLib::DualSegTree(Int32).new([*(1..5)]) { |f, x| x + f } # seg # => [1, 2, 3, 4, 5] # seg[0, 2] = 10 # seg # => [11, 12, 3, 4, 5] # ``` def []=(start : Int, count : Int, applicator : T) apply(start, count, applicator) end # `range` の表す区間に `applicator` を作用させます。 # # ``` # seg = NgLib::DualSegTree(Int32).new([*(1..5)]) { |f, x| x + f } # seg # => [1, 2, 3, 4, 5] # seg.apply(0...2, 10) # seg # => [11, 12, 3, 4, 5] # ``` def apply(range : Range(Int?, Int?), applicator : T) l = (range.begin || 0) r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1) return if l >= r l += @n_leaves r += @n_leaves push(l >> l.to_i32.trailing_zeros_count) push((r >> r.to_i32.trailing_zeros_count) - 1) while l < r if l.odd? x = @nodes[l] apply_impl(l, applicator, x) l += 1 end if r.odd? r -= 1 x = @nodes[r] apply_impl(r, applicator, x) end l >>= 1 r >>= 1 end self end # `start` 番目から `count` 個までの各要素に `applicator` を作用させます。 # # ``` # seg = NgLib::DualSegTree(Int32).new([*(1..5)]) { |f, x| x + f } # seg # => [1, 2, 3, 4, 5] # seg.apply(0, 2, 10) # seg # => [11, 12, 3, 4, 5] # ``` def apply(start : Int, count : Int, application : T) apply((start...{start + count, @size}.min), application) end # すべての要素に `application` を作用させます。 # # ``` # seg = NgLib::DualSegTree(Int32).new([*(1..5)]) { |f, x| x + f } # seg # => [1, 2, 3, 4, 5] # seg.all_apply(10) # seg # => [11, 12, 13, 14, 15] # ``` def apply_all(applicator : T) apply(.., applicator) end def to_s(io : IO) (0...@size).map { |i| self[i] || 'e' }.to_a.to_s(io) end private def push(node_index : Int) return if node_index.zero? r = 31 - node_index.to_i32.leading_zeros_count r.downto(1) do |i| j = node_index >> i f = @nodes[j] {2*j, 2*j + 1}.each do |child| x = @nodes[child] apply_impl(child, f, x) end @nodes[j] = nil end end @[AlwaysInline] private def apply_impl(node_index : Int, applicator : T?, value : T?) return if applicator.nil? @nodes[node_index] = value.nil? ? applicator : @application.call(applicator, value) end end end # require "./aatree_set.cr" module NgLib # 長さ $n$ の整数列 $a_0, a_1, \cdots, a_{n-1}$ について、 # $[l, r)$ に $x$ が何回現れるかを $O(\log{N})$ で計算するクラスです。 class DynamicRangeFrequency(T) @size : Int32 @map : Hash(Int32, NgLib::AATreeSet(Int32)) @values : Array(T) def initialize(array : Array(T)) @values = array.clone @size = array.size @map = Hash(Int32, NgLib::AATreeSet(Int32)).new array.each_with_index do |a, i| @map[a] = NgLib::AATreeSet(Int32).new unless @map.has_key?(a) @map[a] << i end end def count(range : Range(Int?, Int?), x : T) left = (range.begin || 0).to_i32 right = ((range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1)).to_i32 v = @map[x]? || NgLib::AATreeSet(Int32).new lower_bound(v, right) - lower_bound(v, left) end def []=(i : Int, x : T) @map[@values[i]].delete(i.to_i32) @map[x] = NgLib::AATreeSet(Int32).new unless @map.has_key?(x) @map[x] << i.to_i32 @values[i] = x end private def lower_bound(v : NgLib::AATreeSet(Int32), x : Int32) v.greater_equal_index(x) || v.size end end end module NgLib # 動的な数列 $a$ に対して累積和を求めます。 # # 累積和クエリは $O(\log{N})$ で処理することができます。 # # 実装は BIT (Fenwick Tree) です。 # # もし、区間加算 $1$ 点取得がしたい場合は、このライブラリといもす法を組み合わせると良いです。 # 長さ $n$ の数列に対して操作する場合、BIT の長さは $n + 1$ 必要なことに注意してください。 # - $[l, r)$ に $x$ を加算 : `#add(l, x); add(r, -x)` # - $i$ 番目を取得 : `#[..i]` class DynamicRangeSum(T) getter size : Int32 @data : Array(T) # 長さが $n$ で各要素が $0$ の数列 $a$ を構築します。 def initialize(n : Int) @data = Array(T).new(n) { T.zero } @size = @data.size end # 長さが $n$ で各要素が $val$ の数列 $a$ を構築します。 def initialize(n : Int, val : T) @data = Array(T).new(n) { val } @size = @data.size end # 長さが $n$ で $i$ 番目の要素が $elems_i$ の数列 $a$ を構築します。 def initialize(elems : Enumerable(T)) @size = elems.size.to_i32 @data = Array(T).new(@size, T.zero) elems.each_with_index { |x, i| add(i, x) } end # $[l, r)$ 番目までの要素の総和 $\sum_{i=l}^{r-1} a_i$ を $O(\log{N})$ で返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum.get(0, 5) # => 1 + 1 + 2 + 3 + 5 = 12 # ``` def get(l, r) : T raise IndexError.new("`l` and `r` must be 0 <= l <= r <= self.size (#{l}, #{r})") unless 0 <= l && l <= r && r <= @size sum(r) - sum(l) end # $[l, r)$ 番目までの要素の総和 $\sum_{i=l}^{r-1} a_i$ を $O(\log{N})$ で返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum[0, 5] # => 1 + 1 + 2 + 3 + 5 = 12 # ``` def [](l, r) : T get(l, r) end # `range` の表す範囲の要素の総和 $\sum_{i \in range} a_i$ を $O(\log{N})$ で返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum.get(0...5) # => 1 + 1 + 2 + 3 + 5 = 12 # ``` def get(range : Range(Int?, Int?)) : T l = (range.begin || 0) r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1) get(l, r) end # `range` の表す範囲の要素の総和 $\sum_{i \in range} a_i$ を $O(\log{N})$ で返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum[0...5] # => 1 + 1 + 2 + 3 + 5 = 12 # ``` def [](range : Range(Int?, Int?)) : T l = (range.begin || 0) r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1) get(l, r) end # $[l, r)$ 番目までの要素の総和 $\sum_{i=l}^{r-1} a_i$ を $O(\log{N})$ で返します。 # # $0 \leq l \leq r \leq n$ を満たさないとき、`nil` を返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum.get?(0, 5) # => 1 + 1 + 2 + 3 + 5 = 12 # csum.get?(7, 3) # => nil # ``` def get?(l, r) : T? return nil unless 0 <= l && l <= r && r <= @size get(l, r) end # $[l, r)$ 番目までの要素の総和 $\sum_{i=l}^{r-1} a_i$ を $O(\log{N})$ で返します。 # # $0 \leq l \leq r \leq n$ を満たさないとき、`nil` を返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum[0, 5]? # => 1 + 1 + 2 + 3 + 5 = 12 # csum[7, 3]? # => nil # ``` def []?(l, r) : T? get?(l, r) end # `range` の表す範囲の要素の総和 $\sum_{i \in range} a_i$ を $O(\log{N})$ で返します。 # # $0 \leq l \leq r \leq n$ を満たさないとき、`nil` を返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum.get?(0...5) # => 1 + 1 + 2 + 3 + 5 = 12 # csum.get?(7...3) # => nil # ``` def get?(range : Range(Int?, Int?)) : T l = (range.begin || 0) r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1) get?(l, r) end # `range` の表す範囲の要素の総和 $\sum_{i \in range} a_i$ を $O(\log{N})$ で返します。 # # $0 \leq l \leq r \leq n$ を満たさないとき、`nil` を返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum[0...5]? # => 1 + 1 + 2 + 3 + 5 = 12 # csum[7...3]? # => nil # ``` def []?(range : Range(Int?, Int?)) : T l = (range.begin || 0) r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1) get?(l, r) end # $a_i$ を $O(\log{N})$ で返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum.get(5) # => 8 # ``` def get(i) : T get(i, i + 1) end # $a_i$ を $O(\log{N})$ で返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum[5] # => 8 # ``` def [](i) : T get(i, i + 1) end # $a_i$ を $O(\log{N})$ で返します。 # # $0 \leq i \lt n$ を満たさないとき、`nil` を返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum.get?(5) # => 8 # csum.get?(10)? # => nil # ``` def get?(i) : T? get?(i, i + 1) end # $a_i$ を $O(\log{N})$ で返します。 # # $0 \leq i \lt n$ を満たさないとき、`nil` を返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum[5] # => 8 # csum[10]? # => nil # ``` def []?(i) : T? get?(i, i + 1) end # $a_i$ の値を $x$ に更新します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum.get?(0...5) # => 1 + 1 + 2 + 3 + 5 = 12 # csum[0] = 100 # csum.get?(0...5) # => 100 + 1 + 2 + 3 + 5 = 111 # ``` def []=(i : Int, x : T) : T add(i, x - get(i)) end # $a_i$ の値を $x$ に更新します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum.get?(0...5) # => 1 + 1 + 2 + 3 + 5 = 12 # csum.set(0, 100) # csum.get?(0...5) # => 100 + 1 + 2 + 3 + 5 = 111 # ``` def set(i : Int, x : T) self[i] = x end # $a_i$ に $x$ を加算します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = DynamicRangeSum(Int32).new(array) # csum.get?(0...5) # => 1 + 1 + 2 + 3 + 5 = 12 # csum.add(0, 99) # csum.get?(0...5) # => 100 + 1 + 2 + 3 + 5 = 111 # ``` def add(i : Int, x : T) i += 1 while i <= @size @data[i - 1] += x i += i & -i end x end private def sum(r : Int) : T s = T.zero while r > 0 s += @data[r - 1] r -= r & -r end s end end end module NgLib # 変更されうる二次元配列 $A$ に対して、累積和 $\sum_{i=l_i}^{r_i-1} \sum_{j=l_j}^{r_j-1} A_i$ を計算します。 class DynamicRectangleSum(T) getter height : Int32 getter width : Int32 getter csum : Array(Array(T)) def initialize(h : Int, w : Int) @height = h.to_i32 @width = w.to_i32 @csum = Array.new(h + 1) { Array.new(w + 1, T.zero) } end def initialize(grid : Array(Array(T))) @height = grid.size @width = (grid[0]? || [] of T).size @csum = Array.new(@height + 1) { Array.new(@width + 1) { T.zero } } @height.times do |i| @width.times do |j| add(i, j, grid[i][j]) end end end # (y, x) の要素に val を足します。 # # 添字は 0-index です。 # # ``` # csum = DynamicRectangleSum.new(a) # csum.add(y, x, val) # => val # ``` def add(y : Int, x : Int, val : T) : T raise IndexError.new("y = #{y} が配列外参照しています。 (@height = #{@height}") if y < 0 || y >= @height raise IndexError.new("x = #{x} が配列外参照しています。 (@height = #{@width}") if x < 0 || x >= @width i = y + 1 while i <= @height j = x + 1 while j <= @width @csum[i][j] += val j += (j & -j) end i += (i & -i) end val end # (y, x) の要素に val を足します。 # # 添字は 0-index です。 # # 加算に成功した場合 `true` を返します。 # # ``` # csum = DynamicRectangleSum.new(a) # csum.add?(y, x, x) # => true # ``` def add?(y : Int, x : Int, val : T) : Bool return false if y < 0 || y >= @height return false if x < 0 || x >= @width i = y + 1 while i <= @height j = x + 1 while j <= @width @csum[i][j] += val j += (j & -j) end i += (i & -i) end true end # 累積和を返します。 # # [y_begin, y_end), [x_begin, x_end) で指定します。 # # NOTE: このAPIは非推奨です。Rangeで指定することが推奨されます。 def get(y_begin : Int, y_end : Int, x_begin : Int, x_end : Int) : T raise IndexError.new("`y_begin` must be less than or equal to `y_end` (#{y_begin}, #{y_end})") unless y_begin <= y_end raise IndexError.new("`x_begin` must be less than or equal to `x_end` (#{x_begin}, #{x_end})") unless x_begin <= x_end query(y_end, x_end) - query(y_end, x_begin) - query(y_begin, x_end) + query(y_begin, x_begin) end # 累積和を返します。 # # [y_begin, y_end), [x_begin, x_end) で指定します。 # # 範囲内に要素が存在しない場合 nil を返します。 # # NOTE: このAPIは非推奨です。Rangeで指定することが推奨されます。 def get?(y_begin : Int, y_end : Int, x_begin : Int, x_end : Int) : T? return nil unless y_begin <= y_end return nil unless x_begin <= x_end query(y_end, x_end) - query(y_end, x_begin) - query(y_begin, x_end) + query(y_begin, x_end) end # 累積和を取得します。 # # Range(y_begin, y_end), Range(x_begin, x_end) で指定します。 # # ``` # csum = DynamicRectangleSum.new(a) # csum.get(0...h, j..j + 2) # => 28 # ``` def get(y_range : Range(Int?, Int?), x_range : Range(Int?, Int?)) : T y_begin = (y_range.begin || 0) y_end = (y_range.end || @height) + (y_range.exclusive? || y_range.end.nil? ? 0 : 1) x_begin = (x_range.begin || 0) x_end = (x_range.end || @height) + (x_range.exclusive? || x_range.end.nil? ? 0 : 1) get(y_begin, y_end, x_begin, x_end) end # 累積和を返します。 # # [y_begin, y_end), [x_begin, x_end) で指定します。 # # 範囲内に要素が存在しない場合 nil を返します。 # # ``` # csum = DynamicRectangleSum.new(a) # csum.get?(0...h, j..j + 2) # => 28 # csum.get?(0...100*h, j..j + 2) # => nil # ``` def get?(y_range : Range(Int?, Int?), x_range : Range(Int?, Int?)) : T? y_begin = (y_range.begin || 0) y_end = (y_range.end || @height) + (y_range.exclusive? || y_range.end.nil? ? 0 : 1) x_begin = (x_range.begin || 0) x_end = (x_range.end || @height) + (x_range.exclusive? || x_range.end.nil? ? 0 : 1) get?(y_begin, y_end, x_begin, x_end) end def [](y_range : Range(Int?, Int?), x_range : Range(Int?, Int?)) : T get(y_range, x_range) end def []?(y_range : Range(Int?, Int?), x_range : Range(Int?, Int?)) : T? get?(y_range, x_range) end def []=(i : Int, j : Int, val : T) add(i, j, val - get(i..i, j..j)) end private def query(h : Int, w : Int) : T acc = T.zero i = h while i > 0 j = w while j > 0 acc += @csum[i][j] j -= (j & -j) end i -= (i & -i) end acc end end end # require "./aatree_set" # require "./aatree_multiset" module NgLib class MexSet(T) @set : AATreeSet({T, T}) @mset : AATreeMultiset(T) def initialize @set = AATreeSet({T, T}).new([ {T::MIN, T::MIN}, {T::MAX, T::MAX}, ]) @mset = AATreeMultiset(T).new end # 下限値 `inf` で、上限値が `sup` の `MexSet` を構築します。 # # NOTE: 非推奨の API です。mex を求めるときに `inf` のみ指定する方法を推奨します。 # # ``` # # 非負整数に対する MexSet # set = MexSet(Int64).new(0_i64, Int64::MAX) # ``` def initialize(inf : T, sup : T) @set = AATreeSet({T, T}).new([ {inf, inf}, {sup, sup}, ]) @mset = AATreeMultiset(T).new end # 集合に $x$ が含まれるなら `true` を返します。 def includes?(x : T) i = @set.greater_equal_index({x.succ, x.succ}).not_nil! - 1 l, u = @set[i] l <= x && x <= u end # 集合に $x$ を追加します。 def add(x : T) ni = @set.greater_equal_index({x.succ, x.succ}).not_nil! nl, nu = @set[ni] i = ni - 1 l, u = @set[i] if l <= x && x <= u @mset << x return self end if u == x - 1 if nl == x + 1 @set.delete({l, u}) @set.delete({nl, nu}) @set << {l, nu} else @set.delete({l, u}) @set << {l, x} end else if nl == x + 1 @set.delete({nl, nu}) @set << {x, nu} else @set << {x, x} end end self end # 集合に $x$ を追加します。 # # mex の値に変更があったとき `true` を返します。 def add?(x : T) ni = @set.greater_equal_index({x.succ, x.succ}).not_nil! nl, nu = @set[ni] i = ni - 1 l, u = @set[i] if l <= x && x <= u @mset << x return false end if u == x - 1 if nl == x + 1 @set.delete({l, u}) @set.delete({nl, nu}) @s << {l, nu} else @set.delete({l, u}) @set << {l, x} end else if nl == x + 1 @set.delete({nl, nu}) @set << {x, nu} else @set << {x, x} end end true end # 集合から $x$ を削除します。 def delete(x : T) i0 = @mset.greater_equal_index(x) if !i0.nil? && @mset[i0] == x @mset.delete_at(i0.not_nil!) return self end i = @set.greater_equal_index({x + 1, x + 1}).not_nil! - 1 l, u = @set[i] if x < l || u < x return self end @set.delete_at(i) if x == l @set << {l + 1, u} if l < u elsif x == u @set << {l, u - 1} if l < u else @set << {l, x - 1} @set << {x + 1, u} end self end # 集合から $x$ を削除します。 # # 実際に値が削除された場合 `true` を返します。 def delete?(x : T) i0 = @mset.index(x) unless i0.nil? @mset.delete_at(i0.not_nil!) return true end i = @set.greater_equal_index({x + 1, x + 1}).not_nil! - 1 l, u = @set[i] if x < l || u < x return false end @set.delete_at(i) if x == l @set << {l + 1, u} if l < u elsif x == u @set << {l, u - 1} if l < u else @set << {l, x - 1} @set << {x + 1, u} end true end # `inf` を下限値として $\mathrm{mex}$ を求めます。 # # 非負整数に対する $\mathrm{mex}$ はデフォルト値の T.zero を使用すれば良いです。 def mex(inf : T = T.zero) i = @set.greater_equal_index({inf + 1, inf + 1}).not_nil! - 1 l, u = @set[i] if l <= inf && inf <= u return u + 1 end inf end # `add` へのエイリアスです。 def <<(x : T) add(x) end end end module NgLib # 重み付き DSU (Union-Find) などと呼ばれるデータ構造です。 # # Abel 群を載せることができます。すなわち、次のメソッドが実装されているオブジェクトが載ります。 # - `#zero` # - `+` # - `-` class PotentializedDisjointSet(Abel) @n : Int32 @parent_or_size : Array(Int32) @potentials : Array(Abel) # 0 頂点 0 辺の無向グラフを作ります。 # # ``` # ut = PotentializedDisjointSet(Abel).new # ``` def initialize @n = 0 @parent_or_size = Array(Int32).new @potentials = Array(Abel).new end # n 頂点 0 辺の無向グラフを作ります。 # # ``` # n = int # ut = PotentializedDisjointSet(Abel).new(n) # ``` def initialize(size : Int) @n = size.to_i32 @parent_or_size = [-1] * size @potentials = Array.new(size) { Abel.zero } end # w[high] - w[low] = diff となるように、 # 頂点 low と頂点 high を接続します。 # # (w[low] + diff = w[high] と捉えても良いです。) # # 接続後のリーダーを返します。 # # diff は符号付きであることに注意してください。 # また、low と high がすでに接続されている場合の動作は未定義です。 # # ``` # n = int # ut = PotentializedDisjointSet(Abel).new(n) # ut.unite(low: a, high: b, diff: w) # => leader(a) or leader(b) # ``` def unite(low : Int, high : Int, diff : Abel) : Int64 diff += weight(low) - weight(high) x = leader(low) y = leader(high) return x.to_i64 if x == y if -@parent_or_size[x] < -@parent_or_size[y] x, y = y, x diff = -diff end @parent_or_size[x] += @parent_or_size[y] @parent_or_size[y] = x.to_i32 @potentials[y] = diff x.to_i64 end # 頂点 a と頂点 b が同じ連結成分に属しているなら `true` を返します。 # # ``` # n = int # ut = PotentializedDisjointSet(Abel).new(n) # ut.equiv?(u, v) # => true # ``` def equiv?(a : Int, b : Int) : Bool leader(a) == leader(b) end # 頂点 a の属する連結成分のリーダーを返します。 # # ``` # n = int # ut = PotentializedDisjointSet(Abel).new(n) # ut.unite(2, 3, 0) # ut.leader(0) # => 0 # ut.leader(3) # => 2 (3 の可能性もある) # ``` def leader(a : Int) : Int64 return a.to_i64 if @parent_or_size[a] < 0 l = leader(@parent_or_size[a]).to_i32 @potentials[a] += @potentials[@parent_or_size[a]] @parent_or_size[a] = l @parent_or_size[a].to_i64 end # w[high] - w[low] を返します。 # # low と high が同じ連結成分に属していない場合 Abel.zero を返します。 # # ``` # ut = PotentializedDisjointSet(Abel).new(size: 10) # ut.unite(2, 1, 1) # ut.unite(2, 3, 5) # ut.unite(3, 4, 2) # ut.diff(1, 2) # => -1 # ut.diff(2, 1) # => 1 # ut.diff(2, 3) # => 5 # ut.diff(2, 4) # => 7 # ut.diff(0, 9) # => Abel.zero # ``` def diff(low : Int, high : Int) : Abel weight(high) - weight(low) end # w[high] - w[low] を返します。 # # low と high が同じ連結成分に属していない場合 nil を返します。 # # ``` # ut = PotentializedDisjointSet(Abel).new(size: 10) # ut.unite(2, 1, 1) # ut.unite(2, 3, 5) # ut.unite(3, 4, 2) # ut.diff(1, 2) # => -1 # ut.diff(2, 1) # => 1 # ut.diff(2, 3) # => 5 # ut.diff(2, 4) # => 7 # ut.diff(0, 9) # => nil # ``` def diff?(low : Int, high : Int) : Abel? return nil unless equiv?(low, high) weight(high) - weight(low) end # 頂点 a が属する連結成分の大きさを返します。 def size(a : Int) : Int64 -@parent_or_size[leader(a)].to_i64 end # グラフを連結成分に分け、その情報を返します。 # # 返り値は「「一つの連結成分の頂点番号のリスト」のリスト」です。 # (内側外側限らず)Array 内でどの順番で頂点が格納されているかは未定義です。 def groups : Array(Array(Int64)) | Nil leader_buf = Array(Int64).new(@n, 0_i64) group_size = Array(Int64).new(@n, 0_i64) @n.times do |i| leader_buf[i] = leader(i) group_size[leader_buf[i]] += 1 end res = Array.new(@n) { Array(Int64).new } @n.times do |i| res[leader_buf[i]] << i.to_i64 end res.delete([] of Int64) res end private def weight(a : Int) : Abel leader(a) @potentials[a] end end end # require "./aatree_multiset" module NgLib # 昇順(降順) $k$ 個の総和を効率良く求めるためのデータ構造です。 # # 値の追加、削除、$k$ の変更ができます。 class PrioritySum(T) getter k : Int32 getter sum : T @tag : Symbol @mset : NgLib::AATreeMultiset(T) delegate size, to: @mset delegate empty?, to: @mset # 下位 $k$ 要素の総和を求めるためのデータ構造を構築します。 def self.min(k : Int, initial : T = T.zero) self.new(:min, k, initial) end # 上位 $k$ 要素の総和を求めるためのデータ構造を構築します。 def self.max(k : Int, initial : T = T.zero) self.new(:max, k, initial) end def initialize(@tag : Symbol, k : Int, initial : T = T.zero) @k = k.to_i32 @sum = initial @mset = NgLib::AATreeMultiset(T).new end # 要素 $x$ をデータ構造に追加します。 # # 計算量は $O(\log{n})$ です。 def add(x : T) if size < @k @sum += x else kth = @mset.at(kth_index(@k - 1)) @sum = @sum - kth + x if cmp(x, kth) end @mset << x end # Alias for `#add` def <<(x : T) add(x) end # 要素 $x$ をデータ構造から削除します。 # # 計算量は $O(\log{n})$ です。 def delete(x : T) if size <= @k @sum -= x @mset.delete(x) else kth = @mset.at(kth_index(@k)) @sum -= x if cmp(x, kth) @mset.delete(x) kth2 = @mset.at(kth_index(@k - 1)) @sum += kth2 if cmp(x, kth) end end # $k$ の値を変更します。 # # 計算量は $\Delta k \log{\Delta k}$ def k=(k : Int) if @k < k (k - @k).times do |i| break if k + i >= size @sum += @mset.at(kth_index(k + i)) end elsif @k > k (@k - k).times do |i| next if @k - i - 1 >= size break if @k - i < 1 @sum -= @mset.at(kth_index(@k - i - 1)) end end @k = k.to_i32 end private def kth_index(k : Int) case @tag when :max @mset.size - k - 1 when :min k else raise IndexError.new end end private def cmp(a : T, b : T) case @tag when :max a > b when :min a < b end end end end module NgLib # データ列 $a$ に対して、$\min(a_i, a_{i + 1}, \dots, a_{i + \mathrm{length} - 1})$ を、 # 前計算 $O(N)$ クエリが $O(1)$ 求めるためのデータ構造です。 # # セグメント木やSparseTableと異なり、区間長が固定の範囲でしかクエリに答えられませんが高速です。 # # `query(i)` で $[i, i + \mathrm{length})$ が配列の範囲を超えたとき、$[i, \mathrm{a.size})$ だと思って計算します。 # # もし、`query(i)` で $[i - \mathrm{length}, i)$ が求まってほしい場合は、`a = ([eins] * length) + a` としておけば良いです。 # 範囲外の場合は $[0, i)$ だと思って計算されます。 # # なお、$[0, 0)$ の場合は単位元が返ります。 # # もし、query(i) で $[i - \mathrm{length}, i]$ が求まってほしい場合は、`a = ([eins] * (length - 1)) + a` としておけば良いです。 # 範囲外の場合は $[0, i]$ だと思って計算されます。 class SlideMinmax(T) @length : Int32 @data : Array(T) # データ列 $a$ に対して、$\min(a_i, a_{i + 1}, \dots, a_{i + \mathrm{length} - 1})$ を求めるためのデータ構造を構築します。 # # ``` # rmq = SlideMinmax(Int32).min([2, 7, 3, 4, 6], 3) # ``` def self.min(a : Array(T), length : Int) new(a, length, T::MAX) { |lhs, rhs| lhs <= rhs } end # データ列 $a$ に対して、$\max(a_i, a_{i + 1}, \dots, a_{i + \mathrm{length} - 1})$ を求めるためのデータ構造を構築します。 # # ``` # rmq = SlideMinmax(Int32).max([2, 7, 3, 4, 6], 3) # ``` def self.max(a : Array(T), length : Int) new(a, length, T::MIN) { |lhs, rhs| lhs >= rhs } end # データ列 $a$ に対して、$cmp(a_i, a_{i + 1}, \dots, a_{i + \mathrm{length} - 1})$ を求めるためのデータ構造を構築します。 # # `eins` には `@cmp` に対する単位元を渡してください。 # # この API は非推奨です。`self.min` または `self.max` を使用してください。 # # ``` # rmq = SlideMinmax(Int32).new([2, 7, 3, 4, 6], 3) { |a, b| a <= b } # => min query # ``` def initialize(a : Array(T), length : Int, eins : T, &@cmp : (T, T) -> Bool) a.concat([eins] * (length - 1)) @length = length.to_i32 @data = Array(T).new(a.size) tops = Deque(Int32).new a.each_with_index do |e, i| while !tops.empty? && @cmp.call(e, a[tops.last]) tops.pop end tops << i if i - length + 1 >= 0 @data << a[tops.first] if tops.first == i - length + 1 tops.shift end end end end # $[i, i + \mathrm{length})$ の範囲の総積 $cmp(a_i, a_{i + 1}, \dots, a_{i + \mathrm{length} - 1})$ を求めます。 # # $i + \mathrm{length}$ が $a$ のサイズを超える場合は、$[i, \mathrm{a.size})$ で求めます。 # # ``` # rmq = SlideMinmax(Int32).min([2, 7, 3, 4, 6], 3) # rmq.query(0) # => 2 # rmq.query(1) # => 3 # rmq.query(2) # => 3 # rmq.query(3) # => 4 # rmq.query(4) # => 6 # ``` def query(i : Int) : T @data[i] end # $[i, i + \mathrm{length})$ の範囲の総積 $cmp(a_i, a_{i + 1}, \dots, a_{i + \mathrm{length} - 1})$ を求めます。 # # $i + \mathrm{length}$ が $a$ のサイズを超える場合は、$[i, \mathrm{a.size})$ で求めます。 # # 配列外参照の場合は `nil` を返します。 # # ``` # rmq = SlideMinmax(Int32).min([2, 7, 3, 4, 6], 3) # rmq.query?(0) # => 2 # rmq.query?(1) # => 3 # rmq.query?(2) # => 3 # rmq.query?(3) # => 4 # rmq.query?(4) # => 6 # rmq.query?(100) # => nil # ``` def query?(i : Int) : T? @data[i]? end end end module NgLib # 不変な数列 $A$ について、以下の条件を満たす演算を、区間クエリとして処理します。 # - 結合則 : $(x \oplus y) \oplus z = x \oplus (y \oplus z)$ # - 冪等性 : $x \oplus x = x$ # # 前計算は $O(N \log{N})$ かかりますが、区間クエリには $O(1)$ で答えられます。 class SparseTable(T) getter size : Int32 @data : Array(T) @table : Array(Array(T)) @lookup : Array(Int32) @op : (T, T) -> T # $\oplus = \max$ としてデータ構造を構築します。 def self.max(elems : Enumerable(T)) new elems, ->(x : T, y : T) { x > y ? x : y } end # $\oplus = \min$ としてデータ構造を構築します。 def self.min(elems : Enumerable(T)) new elems, ->(x : T, y : T) { x < y ? x : y } end # $\oplus = \mathrm{bitwise-or}$ としてデータ構造を構築します。 def self.bitwise_or(elems : Enumerable(T)) new elems, ->(x : T, y : T) { x | y } end # $\oplus = \mathrm{bitwise-and}$ としてデータ構造を構築します。 def self.bitwise_and(elems : Enumerable(T)) new elems, ->(x : T, y : T) { x & y } end # $\oplus = \mathrm{gcd}$ としてデータ構造を構築します。 def self.gcd(elems : Enumerable(T)) new elems, ->(x : T, y : T) { x.gcd(y) } end # $\oplus = op$ としてデータ構造を構築します。 def initialize(elems : Enumerable(T), @op : (T, T) -> T) @size = elems.size @data = Array(T).new log = (0..).index! { |k| (1 << k) > @size } @table = Array.new(log) { Array(T).new(1 << log, T.zero) } elems.each_with_index { |e, i| @table[0][i] = e; @data << e } (1...log).each do |i| j = 0 while j + (1 << i) <= (1 << log) @table[i][j] = @op.call(@table[i - 1][j], @table[i - 1][j + (1 << (i - 1))]) j += 1 end end @lookup = [0] * (@size + 1) (2..@size).each do |i| @lookup[i] = @lookup[i >> 1] + 1 end end # `range` の表す範囲の要素の総積 $\bigoplus_{i \in range} a_i$ を返します。 # # ``` # rmq = SparseTable(Int32).min([2, 7, 1, 8, 1]) # rmq.prod(0...3) # => 1 # ``` def prod(range : Range(Int?, Int?)) l = (range.begin || 0) r = if range.end.nil? @size else range.end.not_nil! + (range.exclusive? ? 0 : 1) end b = @lookup[r - l] @op.call(@table[b][l], @table[b][r - (1 << b)]) end # `range` の表す範囲の要素の総積 $\bigoplus_{i \in range} a_i$ を返します。 # # $0 \leq l \leq r \leq n$ を満たさないとき、`nil` を返します。 # # ``` # rmq = SparseTable(Int32).min([2, 7, 1, 8, 1]) # rmq.prod(0...3) # => 1 # ``` def prod?(range : Range(Int?, Int?)) l = (range.begin || 0) r = if range.end.nil? @size else range.end.not_nil! + (range.exclusive? ? 0 : 1) end return nil unless 0 <= l && l <= r && r <= @size prod(range) end # $a_i$ を返します。 def [](i) @data[i] end # $a_i$ を返します。 # # 添字が範囲外のとき、`nil` を返します。 def []?(i) @data[i]? end # `prod` へのエイリアスです。 def [](range : Range(Int?, Int?)) prod(range) end # `prod?` へのエイリアスです。 def []?(range : Range(Int?, Int?)) prod?(range) end end end module NgLib # 長さ $n$ の整数列 $a_0, a_1, \cdots, a_{n-1}$ について、 # $[l, r)$ に $x$ が何回現れるかを $O(\log{N})$ で計算するクラスです。 class StaticRangeFrequency(T) @size : Int32 @map : Hash(T, Array(Int32)) def initialize(array : Array(T)) @size = array.size @map = Hash(T, Array(Int32)).new array.each_with_index do |elem, i| @map[elem] = [] of Int32 unless @map.has_key?(elem) @map[elem] << i end end def query(range : Range(Int?, Int?), x : T) left = (range.begin || 0).to_i32 right = ((range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1)).to_i32 v = @map.fetch(x, [] of Int32) lower_bound(v, right) - lower_bound(v, left) end private def lower_bound(v : Array(Int32), x : Int32) v.bsearch_index { |elem| elem >= x } || v.size end end end module NgLib # 不変な数列 $A$ に対して、$\sum_{i=l}^{r-1} A_i$ を前計算 $O(N)$ クエリ $O(1)$ で求めます。 class StaticRangeSum(T) getter size : Int64 getter csum : Array(T) # 配列 `array` に対して累積和を構築します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = StaticRangeSum(Int32).new(array) # ``` def initialize(array : Array(T)) @size = array.size.to_i64 @csum = Array.new(@size + 1, T.zero) @size.times { |i| @csum[i + 1] = @csum[i] + array[i] } end # self[0...r] >= x を満たす最小の self[0...r] def lower_bound(x) ac = @size + 1 wa = 0 while ac - wa > 1 wj = ac + (wa - ac) // 2 if self[0...wj] >= x ac = wj else wa = wj end end ac == @size + 1 ? nil : self[0...ac] end # self[0...r] >= x を満たす最小の r を返します。 def lower_bound_index(x) ac = @size + 1 wa = 0 while ac - wa > 1 wj = ac + (wa - ac) // 2 if self[0...wj] >= x ac = wj else wa = wj end end ac == @size + 1 ? nil : ac end # self[0...r] > x を満たす最小の self[0...r] def upper_bound(x) ac = @size + 1 wa = 0 while ac - wa > 1 wj = ac + (wa - ac) // 2 if self[0...wj] > x ac = wj else wa = wj end end ac == @size + 1 ? nil : self[0...ac] end # self[0...r] > x を満たす最小の r を返します。 def upper_bound_index(x) ac = @size + 1 wa = 0 while ac - wa > 1 wj = ac + (wa - ac) // 2 if self[0...wj] > x ac = wj else wa = wj end end ac == @size + 1 ? nil : ac end # $[l, r)$ 番目までの要素の総和 $\sum_{i=l}^{r-1} a_i$ を $O(1)$ で返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = StaticRangeSum(Int32).new(array) # csum.get(0...5) # => 1 + 1 + 2 + 3 + 5 = 12 # ``` def get(l, r) : T raise IndexError.new("`l` must be less than or equal to `r` (#{l}, #{r})") unless l <= r @csum[r] - @csum[l] end # :ditto: def get(range : Range(Int?, Int?)) : T l = (range.begin || 0) r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1) get(l, r) end # $[l, r)$ 番目までの要素の総和 $\sum_{i=l}^{r-1} a_i$ を $O(1)$ で返します。 # # $l \leq r$ を満たさないとき、`nil` を返します。 # # ``` # array = [1, 1, 2, 3, 5, 8, 13] # csum = StaticRangeSum(Int32).new(array) # csum.get?(0...5) # => 1 + 1 + 2 + 3 + 5 = 12 # csum.get?(7...5) # => nil # ``` def get?(l, r) : T? return nil unless l <= r get(l, r) end # :ditto: def get?(range : Range(Int?, Int?)) : T? l = (range.begin || 0) r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1) get?(l, r) end # $\sum_{i=1}^{r - 1} a_i - \sum_{i=1}^{l} a_i$ を $O(1)$ で返します。 def get!(l, r) : T @csum[r] - @csum[l] end # :ditto: def get!(range : Range(Int?, Int?)) : T l = (range.begin || 0) r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1) get!(l, r) end # `get(l : Int, r : Int)` へのエイリアスです。 def [](l, r) : T get(l, r) end # `get(range : Range(Int?, Int?))` へのエイリアスです。 def [](range : Range(Int?, Int?)) : T get(range) end # `get(l : Int, r : Int)` へのエイリアスです。 def []?(l, r) : T? get?(l, r) end # `get?(range : Range(Int?, Int?))` へのエイリアスです。 def []?(range : Range(Int?, Int?)) : T? get?(range) end end end module NgLib # 不変な二次元配列 $A$ に対して、$\sum_{i=l_i}^{r_i-1} \sum_{j=l_j}^{r_j-1} A_i$ を前計算 $O(N)$ クエリ $O(1)$ で求めます。 class StaticRectangleSum(T) getter height : Int32 getter width : Int32 getter csum : Array(Array(T)) def initialize(grid : Array(Array(T))) @height = grid.size @width = (grid[0]? || [] of T).size @csum = Array.new(@height + 1) { Array.new(@width + 1) { T.zero } } @height.times do |i| @width.times do |j| @csum[i + 1][j + 1] = @csum[i][j + 1] + @csum[i + 1][j] - @csum[i][j] + grid[i][j] end end end # 累積和を返します。 # # [y_begin, y_end), [x_begin, x_end) で指定します。 # # NOTE: このAPIは非推奨です。Rangeで指定することが推奨されます。 def get(y_begin : Int, y_end : Int, x_begin : Int, x_end : Int) : T raise IndexError.new("`y_begin` must be less than or equal to `y_end` (#{y_begin}, #{y_end})") unless y_begin <= y_end raise IndexError.new("`x_begin` must be less than or equal to `x_end` (#{x_begin}, #{x_end})") unless x_begin <= x_end @csum[y_end][x_end] - @csum[y_begin][x_end] - @csum[y_end][x_begin] + @csum[y_begin][x_begin] end # 累積和を返します。 # # [y_begin, y_end), [x_begin, x_end) で指定します。 # # 範囲内に要素が存在しない場合 nil を返します。 # # NOTE: このAPIは非推奨です。Rangeで指定することが推奨されます。 def get?(y_begin : Int, y_end : Int, x_begin : Int, x_end : Int) : T? return nil unless y_begin <= y_end return nil unless x_begin <= x_end @csum[y_end][x_end] - @csum[y_begin][x_end] - @csum[y_end][x_begin] + @csum[y_begin][x_begin] end # 累積和を取得します。 # # Range(y_begin, y_end), Range(x_begin, x_end) で指定します。 # # ``` # csum = StaticRectangleSum.new(a) # csum.get(0...h, j..j + 2) # => 28 # ``` def get(y_range : Range(Int?, Int?), x_range : Range(Int?, Int?)) : T y_begin = (y_range.begin || 0) y_end = (y_range.end || @height) + (y_range.exclusive? || y_range.end.nil? ? 0 : 1) x_begin = (x_range.begin || 0) x_end = (x_range.end || @height) + (x_range.exclusive? || x_range.end.nil? ? 0 : 1) get(y_begin, y_end, x_begin, x_end) end # 累積和を返します。 # # [y_begin, y_end), [x_begin, x_end) で指定します。 # # 範囲内に要素が存在しない場合 nil を返します。 # # ``` # csum = StaticRectangleSum.new(a) # csum.get?(0...h, j..j + 2) # => 28 # csum.get?(0...100*h, j..j + 2) # => nil # ``` def get?(y_range : Range(Int?, Int?), x_range : Range(Int?, Int?)) : T? y_begin = (y_range.begin || 0) y_end = (y_range.end || @height) + (y_range.exclusive? || y_range.end.nil? ? 0 : 1) x_begin = (x_range.begin || 0) x_end = (x_range.end || @height) + (x_range.exclusive? || x_range.end.nil? ? 0 : 1) get?(y_begin, y_end, x_begin, x_end) end def [](y_range : Range(Int?, Int?), x_range : Range(Int?, Int?)) : T get(y_range, x_range) end def []?(y_range : Range(Int?, Int?), x_range : Range(Int?, Int?)) : T? get?(y_range, x_range) end end end module NgLib # `SuccinctBitVector` は簡易ビットベクトル(簡易辞書、Succinct Indexable Dictionary)を提供するクラスです。 # # 前計算 $O(n / 32)$ で次の操作が $O(1)$ くらいでできます。 # # - `.[i]` # => $i$ 番目のビットにアクセスする ($O(1)$) # - `.sum(r)` # => $[0, r)$ にある $1$ の個数を求める ($O(1)$) # - `.kth_bit_index(k)` # => $k$ 番目に現れる $1$ の位置を求める ($O(\log{n})$) # # 例えばこの問題が解けます → [D - Sleep Log](https://atcoder.jp/contests/abc305/tasks/abc305_d) class SuccinctBitVector getter size : UInt32 @blocks : UInt32 @bits : Array(UInt32) @sums : Array(UInt32) @n_zeros : Int32 @n_ones : Int32 # 長さ $n$ のビット列を構築します。 # # 計算量は $O(n / 32)$ です。 def initialize(n : Int) @size = n.to_u32 @blocks = (@size >> 5) + 1 @bits = [0_u32] * @blocks @sums = [0_u32] * @blocks @n_zeros = 0 @n_ones = 0 end # 長さ $n$ のビット列を構築します。 # # ブロックでは $i$ 番目のビットの値を返してください。 # # 計算量は $O(n)$ です。 def initialize(n : Int, & : -> Int) @size = n.to_u32 @blocks = (@size >> 5) + 1 @bits = [0_u32] * @blocks @sums = [0_u32] * @blocks @size.times do |i| set i if (yield i) == 1 end @n_zeros = 0 @n_ones = 0 build end # 左から $i$ 番目のビットを 1 にします。 # # 計算量は $O(1)$ です。 def set(i : Int) @bits[i >> 5] |= 1_u32 << (i & 31) end # 左から $i$ 番目のビットを 0 にします。 # # 計算量は $O(1)$ です。 def reset(i : Int) @bits[i >> 5] &= ~(1_u32 << (i & 31)) end # 総和が計算できるようにデータ構造を構築します。 def build @sums[0] = 0 (1...@blocks).each do |i| @sums[i] = @sums[i - 1] + @bits[i - 1].popcount end @n_zeros = @size.to_i - sum @n_ones = @size.to_i - @n_zeros end # $[0, n)$ の総和を返します。 def sum sum(size) end # $[0, r)$ の総和を返します。 def sum(r) : UInt32 @sums[r >> 5] + (@bits[r >> 5] & ((1_u32 << (r & 31)) - 1)).popcount end # $[l, r)$ の総和を返します。 def sum(l, r) : UInt32 sum(r) - sum(l) end # `range` の範囲で総和を返します。 def sum(range : Range(Int?, Int?)) l = (range.begin || 0) r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1) sum(l, r) end def count_zeros : Int32 @n_zeros end def count_zeros(range : Range(Int?, Int?)) : Int32 l = (range.begin || 0) r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1) (l...r).size.to_i - sum(l, r).to_i end def count_ones : Int32 @n_ones end def count_ones(range : Range(Int?, Int?)) : Int32 l = (range.begin || 0) r = (range.end || @size) + (range.exclusive? || range.end.nil? ? 0 : 1) (l...r).size.to_i - count_zeros(range) end # $i$ 番目のビットを返します。 def [](i) : UInt32 ((@bits[i >> 5] >> (i & 31)) & 1) > 0 ? 1_u32 : 0_u32 end # `range` の範囲の総和を返します。 def [](range : Range(Int?, Int?)) : UInt32 sum(range) end # $k$ 番目に出現する $1$ の位置を求めます。 # # 言い換えると、$sum(i) = k$ となるような最小の $i$ を返します。 # # 存在しない場合は `nil` を返します。 # # 本当は $O(1)$ にできるらしいですが、面倒なので $O(\log{n})$ です。 def kth_bit_index(k) : UInt32 return 0_u32 if k == 0 return nil if sum < k (0..length).bsearch { |right| sum(right) >= k } || raise "Not found" end end end # require "./succinct_bit_vector.cr" module NgLib # **非負整数列** $A$ に対して、順序に関する様々なクエリに答えます。 # # ジェネリクス T は `#[]` などでの返却値の型を指定するものであって、 # 数列の値は非負整数でなければならないことに注意してください。 # # 基本的には `CompressedWaveletMatrix` の方が高速です。 class WaveletMatrix(T) include Indexable(T) @values : Array(UInt64) @n_bits : Int32 @bit_vectors : Array(NgLib::SuccinctBitVector) delegate size, to: @values # 非負整数列 $A$ を `values` によって構築します。 # # ``` # WaveletMatrix.new([1, 3, 2, 5]) # ``` def initialize(values : Array(T)) n = values.size @values = values.map(&.to_u64) @n_bits = log2_floor({1_u64, @values.max}.max) + 1 @bit_vectors = Array.new(@n_bits) { NgLib::SuccinctBitVector.new(n) } cur = @values.clone nxt = Array.new(n) { 0_u64 } (@n_bits - 1).downto(0) do |height| n.times do |i| @bit_vectors[height].set(i) if cur[i].bit(height) == 1 end @bit_vectors[height].build indices = [0, @bit_vectors[height].count_zeros] n.times do |i| bit = @bit_vectors[height][i] nxt[indices[bit]] = cur[i] indices[bit] += 1 end cur, nxt = nxt, cur end end # 長さ $n$ の非負整数列 $A$ を構築します。 # # ``` # WaveletMatrix.new(10) { |i| (5 - i) ** 2 } # ``` def self.new(n : Int, & : Int32 -> T) WaveletMatrix.new(Array.new(n) { |i| yield i }) end # $A$ の `index` 番目の要素を取得します。 # # ``` # wm = WaveletMatrix.new([1, 3, 2, 5]) # wm.unsafe_fetch(0) # => 1 # wm.unsafe_fetch(1) # => 3 # wm.unsafe_fetch(2) # => 2 # wm.unsafe_fetch(3) # => 5 # ``` def unsafe_fetch(index : Int) @values.unsafe_fetch(index) end # `kth` 番目に小さい値を返します。 # # ``` # wm = WaveletMatrix.new([1, 3, 2, 5]) # wm.kth_smallest(0) # => 1 # wm.kth_smallest(1) # => 2 # wm.kth_smallest(2) # => 3 # wm.kth_smallest(3) # => 5 # ``` def kth_smallest(kth : Int32) kth_smallest(.., kth) end # `range` の表す区間に含まれる要素のうち、`kth` 番目に小さい値を返します。 # # ``` # wm = WaveletMatrix.new([1, 3, 2, 5]) # wm.kth_smallest(1..2, 0) # => 2 # wm.kth_smallest(1..2, 1) # => 3 # ``` def kth_smallest(range : Range(Int?, Int?), kth : Int) l = (range.begin || 0) r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1) ret = T.zero (@n_bits - 1).downto(0) do |height| lzeros, rzeros = succ0(l, r, height) if kth < rzeros - lzeros l, r = lzeros, rzeros else kth -= rzeros - lzeros ret |= T.zero.succ << height l += @bit_vectors[height].count_zeros - lzeros r += @bit_vectors[height].count_zeros - rzeros end end ret end # `kth` 番目に大きい値を返します。 # # ``` # wm = WaveletMatrix.new([1, 3, 2, 5]) # wm.kth_smallest(0) # => 5 # wm.kth_smallest(1) # => 3 # wm.kth_smallest(2) # => 2 # wm.kth_smallest(3) # => 1 # ``` def kth_largest(kth : Int32) kth_largest(.., kth) end # `range` の表す区間に含まれる要素のうち、`kth` 番目に大きい値を返します。 # # ``` # wm = WaveletMatrix.new([1, 3, 2, 5]) # wm.kth_largest(1..2, 0) # => 3 # wm.kth_largest(1..2, 1) # => 2 # ``` def kth_largest(range : Range(Int?, Int?), kth : Int) kth_smallest(range, range.size - kth - 1) end # `item` の個数を返します。 # # 計算量は対数オーダーです。 def count(item : T) count(.., item) end # `range` の表す区間に含まれる要素の中で `item` の個数を返します。 def count(range : Range(Int?, Int?), item : T) count(range, item..item) end # `range` の表す区間に含まれる要素の中で `bound` が表す範囲の値の個数を返します。 def count(range : Range(Int?, Int?), bound : Range(T?, T?)) : Int32 lower_bound = (bound.begin || 0) upper_bound = (bound.end || T::MAX) + (bound.exclusive? || bound.end.nil? ? 0 : 1) count_less_eq(range, upper_bound) - count_less_eq(range, lower_bound) end # `range` の表す区間に含まれる要素のうち、`upper_bound` **未満** の値の最大値を返します。 # # 存在しない場合は `nil` を返します。 def prev_value(range : Range(Int?, Int?), upper_bound) : T? cnt = count_less_eq(range, upper_bound) cnt == 0 ? nil : kth_smallest(range, cnt - 1) end # `range` の表す区間に含まれる要素のうち、`lower_bound` **以上** の値の最小値を返します。 # # 存在しない場合は `nil` を返します。 def next_value(range : Range(Int?, Int?), lower_bound) : T? l = (range.begin || 0) r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1) cnt = count_less_eq(range, lower_bound) cnt == (l...r).size ? nil : kth_smallest(range, cnt) end # `range` の表す区間に含まれる要素のうち、`upper_bound` **未満** の値の個数を返します。 # # 存在しない場合は `nil` を返します。 def count_less_eq(range : Range(Int?, Int?), upper_bound) : Int32 l = (range.begin || 0) r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1) return (l...r).size.to_i if upper_bound >= (T.zero.succ << @n_bits) ret = 0 (@n_bits - 1).downto(0) do |height| lzeros, rzeros = succ0(l, r, height) if upper_bound.bit(height) == 1 ret += rzeros - lzeros l += @bit_vectors[height].count_zeros - lzeros r += @bit_vectors[height].count_zeros - rzeros else l, r = lzeros, rzeros end end ret end @[AlwaysInline] private def log2_floor(n : UInt64) : Int32 log2_floor = 63 - n.leading_zeros_count log2_floor + ((n & n - 1) == 0 ? 0 : 1) end @[AlwaysInline] private def succ0(left : Int, right : Int, height : Int) lzeros = left <= 0 ? 0 : @bit_vectors[height].count_zeros(0...Math.min(left, size)) rzeros = right <= 0 ? 0 : @bit_vectors[height].count_zeros(0...Math.min(right, size)) {lzeros, rzeros} end end class CompressedWaveletMatrix(T) include Indexable(T) @wm : WaveletMatrix(Int32) @uniqued : Array(T) delegate size, to: @wm def initialize(values : Array(T)) @uniqued = values.sort.uniq! @wm = WaveletMatrix.new(values.map { |elem| get_index(elem) }) end def self.new(n : Int, & : Int32 -> T) CompressedWaveletMatrix.new(Array.new(n) { |i| yield i }) end def unsafe_fetch(index : Int) @uniqued[@wm.unsafe_fetch(index)] end # `kth` 番目に小さい値を返します。 # # ``` # wm = WaveletMatrix.new([1, 3, 2, 5]) # wm.kth_smallest(0) # => 1 # wm.kth_smallest(1) # => 2 # wm.kth_smallest(2) # => 3 # wm.kth_smallest(3) # => 5 # ``` def kth_smallest(kth : Int32) kth_smallest(.., kth) end # `range` の表す区間に含まれる要素のうち、`kth` 番目に小さい値を返します。 # # ``` # wm = WaveletMatrix.new([1, 3, 2, 5]) # wm.kth_smallest(1..2, 0) # => 2 # wm.kth_smallest(1..2, 1) # => 3 # ``` def kth_smallest(range : Range(Int?, Int?), kth : Int) @uniqued[@wm.kth_smallest(range, kth)] end # `kth` 番目に大きい値を返します。 # # ``` # wm = WaveletMatrix.new([1, 3, 2, 5]) # wm.kth_smallest(0) # => 5 # wm.kth_smallest(1) # => 3 # wm.kth_smallest(2) # => 2 # wm.kth_smallest(3) # => 1 # ``` def kth_largest(kth : Int32) kth_largest(.., kth) end # `range` の表す区間に含まれる要素のうち、`kth` 番目に大きい値を返します。 # # ``` # wm = WaveletMatrix.new([1, 3, 2, 5]) # wm.kth_largest(1..2, 0) # => 3 # wm.kth_largest(1..2, 1) # => 2 # ``` def kth_largest(range : Range(Int?, Int?), kth : Int) kth_smallest(range, range.size - kth - 1) end # `item` の個数を返します。 # # 計算量は対数オーダーです。 def count(item : T) count(.., item) end # `range` の表す区間に含まれる要素の中で `item` の個数を返します。 def count(range : Range(Int?, Int?), item : T) count(range, item..item) end # `range` の表す区間に含まれる要素の中で `bound` が表す範囲の値の個数を返します。 def count(range : Range(Int?, Int?), bound : Range(T?, T?)) : Int32 lower_bound = (bound.begin || 0) upper_bound = (bound.end || T::MAX) + (bound.exclusive? || bound.end.nil? ? 0 : 1) count_less_eq(range, upper_bound) - count_less_eq(range, lower_bound) end # `range` の表す区間に含まれる要素のうち、`upper_bound` **未満** の値の最大値を返します。 # # 存在しない場合は `nil` を返します。 def prev_value(range : Range(Int?, Int?), upper_bound) : T? cnt = count_less_eq(range, upper_bound) cnt == 0 ? nil : kth_smallest(range, cnt - 1) end # `range` の表す区間に含まれる要素のうち、`lower_bound` **以上** の値の最小値を返します。 # # 存在しない場合は `nil` を返します。 def next_value(range : Range(Int?, Int?), lower_bound) : T? l = (range.begin || 0) r = (range.end || size) + (range.exclusive? || range.end.nil? ? 0 : 1) cnt = count_less_eq(range, lower_bound) cnt == (l...r).size ? nil : kth_smallest(range, cnt) end private def get_index(value : T) @uniqued.bsearch_index { |x| x >= value } || @uniqued.size end private def count_less_eq(range : Range(Int?, Int?), upper_bound) : Int32 @wm.count_less_eq(range, get_index(upper_bound)) end end end # require "../constants" module NgLib # $n$ 頂点 $m$ 辺からなるグラフに対して、幅優先探索によって最短経路を求めます。 # # 経路の復元も可能です。 class BfsGraph getter size : Int32 getter graph : Array(Array(Int32)) # $n$ 頂点 $0$ 辺からなるグラフを作成します。 # # ``` # graph = BfsGraph.new(n) # ``` def initialize(n : Int) @size = n.to_i64.to_i32 @graph = Array.new(@size) { Array(Int32).new } end # 辺 $(u, v)$ を追加します。 # # `directed` が `true` の場合、 # 有向辺とみなして、$u$ から $v$ への辺のみ生やします。 # # ``` # graph = BfsGraph.new(n) # graph.add_edge(u, v) # => (u) <---w---> (v) # graph.add_edge(u, v, directed: true) # => (u) ----w---> (v) # ``` def add_edge(u : Int, v : Int, directed : Bool = true) @graph[u.to_i32] << v.to_i32 @graph[v.to_i32] << u.to_i32 unless directed end # 全点対間の最短経路長を返します。 # # ``` # dists = graph.shortest_path # dists # => [[0, 1, 3], [1, 0, 2], [1, 1, 0]] # ``` def shortest_path (0...@size).map { |start| shortest_path(start) } end # 始点 `start` から各頂点への最短経路長を返します。 # # ``` # dist = graph.shortest_path(start: 2) # dist # => [3, 8, 0, 7, 1] # ``` def shortest_path(start : Int) queue = Deque.new([start.to_i32]) dist = Array.new(@size) { |i| i == start ? 0_i64 : OO } until queue.empty? from = queue.shift @graph[from].each do |adj| next if dist[adj] != OO dist[adj] = dist[from] + 1 queue << adj end end dist end # 始点 `start` から終点 `dest` への最短経路長を返します。 # # ``` # dist = graph.shortest_path(start: 2, dest: 0) # dist # => 12 # ``` def shortest_path(start : Int, dest : Int) shortest_path(start)[dest] end end end # require "../constants" module NgLib class BinaryBfsGraph private struct Edge getter to : Int32 getter weight : Int32 def initialize(t : Int, w : Int) @to = t.to_i32 @weight = w.to_i32 end end getter size : Int32 getter graph : Array(Array(Edge)) # $n$ 頂点 $0$ 辺からなるグラフを作成します。 # # ``` # graph = BfsGraph.new(n) # ``` def initialize(n : Int) @size = n.to_i64.to_i32 @graph = Array.new(@size) { [] of Edge } end # 辺 $(u, v, w)$ を追加します。 # # $w$ は $0$ または $1$ である必要があります。 # # `directed` が `true` の場合、 # 有向辺とみなして、$u$ から $v$ への辺のみ生やします。 # # ``` # graph = BfsGraph.new(n) # graph.add_edge(u, v) # => (u) <---w---> (v) # graph.add_edge(u, v, directed: true) # => (u) ----w---> (v) # ``` def add_edge(u : Int, v : Int, w : Int, directed : Bool = true) raise Exception.new("w should be 0 or 1") unless w.in?({0, 1}) @graph[u.to_i32] << Edge.new(v, w) @graph[v.to_i32] << Edge.new(u, w) unless directed end # 全点対間の最短経路長を返します。 # # ``` # dists = graph.shortest_path # dists # => [[0, 1, 3], [1, 0, 2], [1, 1, 0]] # ``` def shortest_path (0...@size).map { |start| shortest_path(start) } end # 始点 `start` から各頂点への最短経路長を返します。 # # ``` # dist = graph.shortest_path(start: 2) # dist # => [3, 8, 0, 7, 1] # ``` def shortest_path(start : Int) deque = Deque.new([start.to_i32]) dist = Array.new(@size) { |i| i == start ? 0_i64 : OO } until deque.empty? from = deque.shift @graph[from].each do |e| d = dist[from] + e.weight if d < dist[e.to] dist[e.to] = d if e.weight == 0 deque.unshift(e.to) else deque << e.to end end end end dist end # 始点 `start` から終点 `dest` への最短経路長を返します。 # # ``` # dist = graph.shortest_path(start: 2, dest: 0) # dist # => 12 # ``` def shortest_path(start : Int, dest : Int) shortest_path(start)[dest] end end end require "atcoder/priority_queue" module NgLib abstract struct Weight include Comparable(Weight) def self.zero : self end def self.inf : self end def +(other : self) end def <=>(other : self) end end # $n$ 頂点 $m$ 辺の重み付きグラフに対して、最短経路を求めます。 # # 経路の復元も可能です。 # # 辺の重みが非負整数で表せる場合は `nglib/graph/radix_dijkstra` を使ったほうが高速です。 class DijkstraGraph(Weight) record Edge(W), target : Int32, weight : W getter size : Int32 @graph : Array(Array(Edge(Weight))) # $n$ 頂点 $0$ 辺からなるグラフを作成します。 # # ``` # graph = Dijkstra.new(n) # ``` def initialize(n : Int) @size = n.to_i32 @graph = Array.new(@size) { Array(Edge(Weight)).new } end # 非負整数の重み $w$ の辺 $(u, v)$ を追加します。 # # `directed` が `true` の場合、 # 有向辺とみなして、$u$ から $v$ への辺のみ生やします。 # # ``` # graph = Dijkstra.new(n) # graph.add_edge(u, v, w) # => (u) <---w---> (v) # graph.add_edge(u, v, w, directed: true) # => (u) ----w---> (v) # ``` def add_edge(u : Int, v : Int, w : Weight, directed : Bool = true) @graph[u.to_i32] << Edge.new(v.to_i32, w) @graph[v.to_i32] << Edge.new(u.to_i32, w) unless directed end # 全点対間の最短経路長を返します。 # # ``` # dists = graph.shortest_path # dists # => [[0, 1, 3], [1, 0, 2], [1, 1, 0]] # ``` def shortest_path : Array(Array(Weight)) (0...@size).map { |start| shortest_path(start) } end # 始点 `start` から各頂点への最短経路長を返します。 # # ``` # dist = graph.shortest_path(2) # dist # => [3, 8, 0, 7, 1] # ``` def shortest_path(start : Int) : Array(Weight) dist = [Weight.inf] * @size dist[start] = Weight.zero next_node = AtCoder::PriorityQueue({Weight, Int32}).min next_node << {Weight.zero, start.to_i32} until next_node.empty? d, source = next_node.pop.not_nil! next if dist[source] < d @graph[source].each do |e| next_cost = dist[source] + e.weight if next_cost < dist[e.target] dist[e.target] = next_cost next_node << {next_cost, e.target} end end end dist end # 始点 `start` から終点 `dest` への最短経路長を返します。 # # ``` # dist = graph.shortest_path(start: 2, dest: 0) # dist # => 12 # ``` def shortest_path(start : Int, dest : Int) : Weight shortest_path(start)[dest] end # 始点 `start` から終点 `dest` への最短経路の一例を返します。 # # ``` # route = graph.shortest_path_route(start: 2, dest: 0) # route # => [2, 7, 1, 0] # ``` def shortest_path_route(start, dest) prev = impl_memo_route(start) res = Array(Int32).new now : Int32? = dest.to_i32 until now.nil? res << now.not_nil! now = prev[now] end res.reverse end # 始点 `start` から最短路木を構築します。 # # 最短路木は `start` からの最短経路のみを残した全域木です。 # # ``` # route = graph.shortest_path_route(start: 2, dest: 0) # route # => [2, 7, 1, 0] # ``` def shortest_path_tree(start, directed : Bool = true) : Array(Array(Int32)) dist = [Weight.inf] * @size dist[start] = Weight.zero next_node = AtCoder::PriorityQueue({Weight, Int32}).min next_node << {Weight.zero, start.to_i32} birth = [-1] * @size until next_node.empty? d, source = next_node.pop.not_nil! next if dist[source] < d @graph[source].each do |e| next_cost = dist[source] + e.weight if next_cost < dist[e.target] dist[e.target] = next_cost next_node << {next_cost, e.target} birth[e.target] = source end end end tree = Array.new(@size) { [] of Int32 } @size.times do |target| source = birth[target] next if source == -1 tree[source] << target tree[target] << source unless directed end tree end # 経路復元のための「どこから移動してきたか」を # メモした配列を返します。 private def impl_memo_route(start) dist = [Weight.inf] * @size dist[start] = Weight.zero prev = Array(Int32?).new(@size) { nil } next_node = AtCoder::PriorityQueue({Weight, Int32}).min next_node << {Weight.zero, start.to_i32} until next_node.empty? d, source = next_node.pop.not_nil! next if dist[source] < d @graph[source].each do |e| next_cost = dist[source] + e.weight if next_cost < dist[e.target] dist[e.target] = next_cost prev[e.target] = source next_node << {next_cost, e.target} end end end prev end end end # require "../data_structure/sparse_table.cr" module NgLib class EulerTourTree @size : Int32 @graph : Array(Array(Int32)) @time : Int32 getter itinerary : Array(Int32) getter login : Array(Int32) getter logout : Array(Int32) getter parents : Array(Int32) getter depths : Array(Int32) @min_depth : SparseTable(Int32) # n 頂点 0 辺のグラフを生成します。 # # ``` # tree = EulerTourTree.new(n) # ``` def initialize(n : Int) @size = n.to_i32 @graph = Array.new(n) { [] of Int32 } @time = 0 @itinerary = [] of Int32 @login = [-1] * n @logout = [-1] * n @parents = [-1] * n @depths = [] of Int32 @min_depth = uninitialized SparseTable(Int32) end # 無向辺 (u, v) を追加します。 # # ``` # tree = EulerTourTree.new(n) # tree.add_edge(u, v) # ``` def add_edge(u : Int, v : Int) @graph[u] << v.to_i32 @graph[v] << u.to_i32 end # 実際にオイラーツアーします。 # # itinerary などが正しく構築されます。 # # ``` # tree = EulerTourTree.new(n) # tree.build # tree.build(root: 3) # ``` def build(root : Int = 0) dfs(root.to_i32, -1, 0) @min_depth = SparseTable(Int32).min(@depths) end # 頂点 u と頂点 v の最小共通祖先を返します。 # # ``` # tree = EulerTourTree.new(n) # tree.build # tree.lca(u, v) # ``` def lca(u : Int, v : Int) l = {@login[u], @logout[u], @login[v], @logout[v]}.min r = {@login[u], @logout[u], @login[v], @logout[v]}.max mn = @min_depth[l...r] ok = r ng = l while ok - ng > 1 mid = (ok + ng) // 2 if @min_depth[l...mid] == mn ok = mid else ng = mid end end @itinerary[ok - 1] end # 根を x とする部分木の頂点のコストの総和を求めるためのリストを返します。 # # `zero` には「総和」の単位元を渡してください。 # # リストの i 番目には時刻 i に訪れた頂点のコストが格納されています。 # ただし、その頂点が時刻 i 以前に訪問済みだった場合は zero が格納されます。 # # 頂点 v のコストが w 変更される場合は、 # このリストの @login[v] 番目を w にすると良いです。 # # ``` # tree = EulerTourTree.new(n) # tree.build # node_cost = Array.new(n) # tree.subtree_node_costs { |v| node_costs[v] } # ``` def subtree_node_costs(zero = U.zero, & : Int32 -> U) forall U costs = Array(U).new(itinerary.size) { zero } @size.times do |v| t = @login[v] cost = yield v costs[t] = cost end end # 根を x とする部分木の辺のコストの総和を求めるためのリストを返します。 # # `zero` には「総和」の単位元を渡してください。 # # リストの i 番目には時刻 i に訪れた頂点と、 # 時刻 i - 1 に訪れた頂点を結ぶ辺のコストが格納されています。 # ただし、その辺が時刻 i 以前に訪問済みだった場合は zero が格納されます # # 無向辺 (u, v) のコストが w に変更される場合は、 # 頂点 u, v のうち、**後に**訪れる方の頂点を bwd として、 # このリストの @login[bwd] 番目を w にすると良いです。 # # ``` # tree = EulerTourTree.new(n) # tree.build # edge_cost = Hash({Int32, Int32}, Int64).new # tree.subtree_edge_costs { |u, v| edge_cost[{u, v}] } # ``` def subtree_edge_costs(zero = U.zero, & : Int32, Int32 -> U) forall U costs = Array(U).new(itinerary.size) { zero } @size.times do |v| t = @login[v] next if t == 0 par = @parents[v] costs[t] = yield par, v end end # 根からのパスに現れる頂点のコストの総和を求めるためのリストを返します。 # # リストの i 番目には時刻 i に訪れた頂点のコストが格納されています。 # ただし、その頂点が時刻 i 以前に訪問済みだった場合は、マイナス倍された値が格納されます。 # # 頂点 v のコストが w 変更される場合は、 # このリストの @login[v] 番目を w に、それ以降で v が現れる時刻番目を -w にすると良いです。 # (つまり、計算量がそこそこかかりそう?) # # ``` # tree = EulerTourTree.new(n) # tree.build # tree.node_costs_on_root_path { |v| node_costs[v] } # ``` def node_costs_on_root_path(& : Int32 -> U) forall U root = itinerary[0] costs = Array(U).new(itinerary.size + 1) costs << yield root (itinerary.size - 1).times do |i| s = itinerary[i] t = itinerary[i + 1] if @parents[t] == s costs << yield t else costs << -(yield s) end end costs << -(yield root) end # 根からのパスに現れる辺のコストの総和を求めるためのリストを返します。 # # `zero` には「総和」の単位元を渡してください。 # # リストの i 番目には時刻 i に訪れた頂点と、 # 時刻 i - 1 に訪れた頂点を結ぶ辺のコストが格納されています。 # ただし、その辺が時刻 i 以前に訪問済みだった場合は、マイナス倍された値が格納されます。 # また、時刻 0 と時刻 itinerary.size には zero が格納されます。 # # 辺 (u, v) のコストが w 変更される場合は、 # 頂点 u, v のうち、**後に**訪れる方の頂点 bwd をとして、 # このリストの @login[bwd] 番目を w に、@logout[bwd] 番目を -w にしてください。 # # ``` # tree = EulerTourTree.new(n) # tree.build # tree.edge_costs_on_root_path { |u, v| edge_cost[{u, v}] } # ``` def edge_costs_on_root_path(zero = U.zero, & : Int32, Int32 -> U) forall U costs = Array(U).new(itinerary.size + 1) costs << U.zero (itinerary.size - 1).times do |i| s = itinerary[i] t = itinerary[i + 1] if @parents[t] == s costs << yield s, t else costs << -(yield s, t) end end costs << U.zero end # 根を subroot とする部分木の頂点のコストの総和を返します。 # # ブロックでは整数 l, r が与えられるので、 # subtree_node_costs の [l, r) までの総和を返してください。 # # ``` # tree = EulerTourTree.new(n) # tree.build # node_cost = Array.new(n) # costs = tree.subtree_node_costs { |v| node_costs[v] } # csum = StaticRangeSum.new(costs) # tree.sum_subtree_node_cost(subroot: 3) { |l, r| csum[l...r] } # ``` def sum_subtree_node_cost(subroot : Int, & : Int32, Int32 -> U) forall U l = @login[subroot] r = @logout[subroot] yield l, r end # 根を subroot とする部分木の辺のコストの総和を返します。 # # ブロックでは整数 l, r が与えられるので、 # subtree_edge_costs の [l, r) までの総和を返してください。 # # ``` # tree = EulerTourTree.new(n) # tree.build # edge_cost = Hash({Int32, Int32}, Int64).new # costs = tree.subtree_node_costs { |u, v| edge_cost[{u, v}] } # csum = StaticRangeSum.new(costs) # tree.sum_subtree_edge_cost(subroot: 3) { |l, r| csum[l...r] } # ``` def sum_subtree_edge_cost(subroot : Int, & : Int32, Int32 -> U) forall U l = @login[subroot] r = @logout[subroot] yield l + 1, r end # 根から頂点 v へのパスに現れる頂点のコストの総和を返します。 # # ブロックでは整数 l, r が与えられるので、 # root_path_node_costs の [l, r) までの総和を返してください。 # # ``` # tree = EulerTourTree.new(n) # tree.build # costs = tree.root_path_node_costs { |v| node_costs[v] } # csum = StaticRangeSum.new(costs) # tree.sum_node_cost_on_path_to(v: 3) { |l, r| csum[l...r] } # ``` def sum_node_cost_on_path_to(v : Int, & : Int32, Int32 -> U) forall U r = @login[v] yield 0, r + 1 end # 根から頂点 v へのパスに現れる辺のコストの総和を返します。 # # ブロックでは整数 l, r が与えられるので、 # root_path_edge_costs の [l, r) までの総和を返してください。 # # ``` # tree = EulerTourTree.new(n) # tree.build # costs = tree.root_path_node_costs { |v| node_costs[v] } # csum = StaticRangeSum.new(costs) # tree.sum_edge_cost_on_path_to(v: 3) { |l, r| csum[l...r] } # ``` def sum_edge_cost_on_path_to(v : Int, & : Int32, Int32 -> U) forall U r = @login[v] yield 1, r + 1 end private def dfs(v : Int32, par : Int32, d : Int32) @itinerary << v @depths << d @parents[v] = par @login[v] = @time @time += 1 @graph[v].each do |child| next if child == par dfs(child, v, d + 1) @itinerary << v @depths << d end @logout[v] = @time @time += 1 end end end module NgLib # ワーシャル・フロイド法の実装です。 # # (負を含む)重み付きグラフに対して、 # 全点対最短経路長が $O(V^3)$ で求まります。 class FloydWarshallGraph(T) getter size : Int32 getter mat : Array(Array(T?)) # $n$ 頂点 $0$ 辺のグラフを作ります。 # # ``` # n = 10 # NgLib::FloydWarshallGraph(Int64).new(n) # ``` def initialize(n : Int) @size = n.to_i32 @mat = Array.new(n) { Array.new(n) { nil.as(T?) } } @size.times do |i| @mat[i][i] = T.zero end end # 隣接行列に従ってグラフを作ります。 # # `nil` は辺が存在しないことを表します。 # 無限大の重みを持つ辺と捉えても良いです。 # # ``` # mat = [[0, 3, 1], [-2, 0, 4], [nil, nil, 0]] # NgLib::FloydWarshallGraph(Int32).new(mat) # ``` def initialize(@mat : Array(Array(T?))) @size = @mat.size @size.times do |i| @mat[i][i] = T.zero end end # :ditto: def initialize(matrix : Array(Array(T?) | Array(T))) @mat = matrix.map { |line| line.map { |v| v.as(T?) } } @size = @mat.size @size.times do |i| @mat[i][i] = T.zero end end # 重みが $w$ の辺 $(u, v)$ を追加します。 # # `directed` が `true` である場合、有向辺として追加します。 # # ``` # n, m = read_line.split.map &.to_i # graph = NgLib::FloydWarshallGraph.new(n) # m.times do # u, v, w = read_line.split.map &.to_i64 # u -= 1; v -= 1 # 0-index # graph.add_edge(u, v, w, directed: true) # end # ``` def add_edge(u : Int, v : Int, w : T, directed : Bool = true) uv = @mat[u][v] if uv.nil? @mat[u][v] = w else @mat[u][v] = {uv, w}.min end unless directed vu = @mat[v][u] if vu.nil? @mat[v][u] = w else @mat[v][u] = {vu, w}.min end end end # 全点対最短経路長を返します。 # # どのような経路を辿っても到達できない場合は `nil` が格納されます。 # # ``` # mat = [[0, 3, 1], [-2, 0, 4], [nil, nil, 0]] # graph = NgLib::FloydWarshallGraph.new(mat) # d = graph.shortest_path # => [[0, 3, 1], [-2, 0, -1], [nil, nil, 0]] # d[0][1] # => 3 (i から j への最短経路長) # ``` def shortest_path dist = @mat.clone @size.times do |via| @size.times do |from| @size.times do |dest| d1 = dist[from][via] d2 = dist[via][dest] next if d1.nil? next if d2.nil? d = dist[from][dest] if d.nil? || d > d1 + d2 dist[from][dest] = d1 + d2 end end end end dist end end end # require "../constants.cr" module NgLib # 最近共通祖先を求めるライブラリです。 class LCA alias Graph = Array(Array(Int64)) getter parent : Array(Array(Int64)) getter dist : Array(Int64) @graph : Graph # 木構造グラフ `graph` に対して、`root` を根とする LCA を構築します。 def initialize(@graph : Graph, root = 0_i64) n = graph.size k = 1_i64 while (1_i64 << k) < n k += 1 end @parent = Array.new(k) { [-1_i64] * n } @dist = [OO] * n dfs(root, -1_i64, 0_i64) (k - 1).times do |i| n.times do |v| if @parent[i][v] < 0 @parent[i + 1][v] = -1_i64 else @parent[i + 1][v] = @parent[i][@parent[i][v]] end end end end # 頂点 $u$ と 頂点 $v$ の最近共通祖先を返します。 def ancestor(u : Int, v : Int) : Int64 if @dist[u] < @dist[v] u, v = v, u end n = @parent.size n.times do |k| u = @parent[k][u] if (dist[u] - dist[v]).bit(k) == 1 end return u if u == v (n - 1).downto(0) do |k| if @parent[k][u] != @parent[k][v] u, v = @parent[k][u], @parent[k][v] end end @parent[0][u] end # 頂点 $u$ と頂点 $v$ の距離を返します。 def distance_between(u : Int, v : Int) : Int64 dist[u] + dist[v] - dist[ancestor(u, v)] * 2 end # 頂点 $u$ から頂点 $v$ までのパスに頂点 $a$ が含まれているか返します。 def on_path?(u : Int, v : Int, a : Int) : Bool distance_between(u, a) + distance_between(a, v) == distance_between(u, v) end private def dfs(root : Int64, par : Int64, d : Int64) @parent[0][root] = par @dist[root] = d @graph[root].each do |child| next if child == par dfs(child, root, d + 1) end end end end module NgLib class MaxFlowGraph(Cap) class Edge(T) getter to : Int32 getter rev : Int32 property cap : T def initialize(@to, @rev, @cap) end end getter size : Int32 @graph : Array(Array(Edge(Cap))) @pos : Array({Int32, Int32}) # $0$ 頂点 $0$ 辺のグラフを作ります。 # # ``` # graph = MaxFlowGraph(Int64).new # ``` def initialize @size = 0 @graph = [] of Array(Edge(Cap)) @pos = [] of {Int32, Int32} end # $n$ 頂点 $0$ 辺のグラフを作ります。 # # ``` # n = 10 # graph = MaxFlowGraph(Int64).new(n) # ``` def initialize(n : Int) @size = n.to_i32 @graph = Array.new(n) { Array(Edge(Cap)).new } @pos = [] of {Int32, Int32} end # 頂点 `from` から頂点 `to` へ最大容量 `cap`、流量 $0$ の辺を追加します。 # # 何番目に追加された辺であるかを返します。 # # ``` # n = 10 # graph = MaxFlowGraph(Int64).new(n) # graph.add_edge(0, 1, 1) # => 0 # graph.add_edge(1, 3, 2) # => 1 # graph.add_edge(5, 6, 8) # => 2 # ``` def add_edge(from : Int, to : Int, cap : Cap) : Int32 m = @pos.size @pos << {from.to_i32, @graph[from].size} from_id = @graph[from].size to_id = @graph[to].size to_id += 1 if from == to @graph[from] << Edge(Cap).new(to.to_i32, to_id, cap) @graph[to] << Edge(Cap).new(from.to_i32, from_id, Cap.zero) m end # 今の内部の辺の状態を返します。 # # 辺の順番は `add_edge` での追加順と同じです。 def get_edge(i : Int) e = @graph[@pos[i][0]][@pos[i][1]] re = @graph[e.to][e.rev] {from: @pos[i][0], to: e.to, cap: e.cap + re.cap, flow: re.cap} end # 今の内部の辺の状態を返します。 # # 辺の順番は `add_edge` での追加順と同じです。 def edges Array.new(@pos.size) { |i| get_edge(i) } end # $i$ 番目に変更された辺の容量、流量をそれぞれ `new_cap`, `new_flow` に変更します。 # # 他の辺の容量、流量は変更しません。 def change_edge(i : Int, new_cap : Cap, new_flow : Cap) @graph[@pos[i].first][@pos[i].second].cap = new_cap - new_flow @graph[_e.to][_e.rev].cap = new_flow end # 頂点 $s$ から頂点 $t$ へ流せるだけ流し、流せた量を返します。 # # 複数回呼ぶことも可能ですが、同じ結果を返すわけではありません。 # 挙動については以下を参考にしてください。 # - https://atcoder.github.io/ac-library/document_ja/appendix.html # # ``` # n = 4 # graph = MaxFlowGraph(Int64).new(n) # graph.add_edge(0, 1, 10) # => 0 # graph.add_edge(1, 2, 2) # => 1 # graph.add_edge(0, 2, 5) # => 2 # graph.add_edge(1, 3, 6) # => 3 # graph.add_edge(2, 3, 3) # => 4 # # graph.flow(0, 3) # => 9 # ``` def flow(s : Int, t : Int) flow(s, t, Cap::MAX) end # 頂点 $s$ から頂点 $t$ へ流せるだけ流し、流せた量を返します。 # # 複数回呼ぶことも可能ですが、同じ結果を返すわけではありません。 # 挙動については以下を参考にしてください。 # - https://atcoder.github.io/ac-library/document_ja/appendix.html # # ``` # n = 4 # graph = MaxFlowGraph(Int64).new(n) # graph.add_edge(0, 1, 10) # => 0 # graph.add_edge(1, 2, 2) # => 1 # graph.add_edge(0, 2, 5) # => 2 # graph.add_edge(1, 3, 6) # => 3 # graph.add_edge(2, 3, 3) # => 4 # # graph.flow(0, 3, 6) # => 6 # graph.flow(0, 3, 100) # => 9 (本来の挙動であれば 3 を返します。) # ``` # ameba:disable Metrics/CyclomaticComplexity def flow(s : Int, t : Int, flow_limit : Cap) level = [0] * @size iter = [0] * @size bfs = ->{ level = [-1] * @size level[s] = 0 que = Deque(Int32).new([s.to_i32]) until que.empty? v = que.shift @graph[v].each do |e| next if e.cap == 0 || level[e.to] >= 0 level[e.to] = level[v] + 1 next if e.to == t que << e.to end end } dfs = uninitialized Int32, Cap -> Cap dfs = ->(v : Int32, up : Cap) { return up if v == s res = Cap.zero level_v = level[v] (iter[v]...@graph[v].size).each do |i| e = @graph[v][i] next if level_v <= level[e.to] || @graph[e.to][e.rev].cap == 0 d = dfs.call(e.to, Math.min(up - res, @graph[e.to][e.rev].cap)) next if d <= 0 @graph[v][i].cap += d @graph[e.to][e.rev].cap -= d res += d return res if res == up end level[v] = @size res } res = Cap.zero while res < flow_limit bfs.call break if level[t] == -1 iter = [0] * @size f = dfs.call(t.to_i32, flow_limit - res) break if f == 0 res += f end res end # 長さ $n$ の配列を返します。 # $i$ 番目の要素には、頂点 $s$ から $i$ へ残余グラフで到達可能なとき、またその時のみ `true` を返します。 # # `flow(s, t)` を `flow_limit` なしでちょうど一回呼んだ後に呼ぶと、 # 返り値は $s, t$ 間の `mincut` に対応します。 def min_cut(s : Int) closed = [false] * @size que = Deque(Int32).new([s.to_i32]) unless que.empty? now = que.shift closed[now] = true @graph[now].each do |e| if e.cap != 0 && !closed[e.to] closed[e.to] = true que << e.to end end end closed end end end require "atcoder/dsu" module NgLib # $n$ 頂点の重み付きグラフについて、最小/最大全域木を構築します。 # # Kruskal 法による実装です。 class MSTGraph(T) getter size : Int64 @edges : Array({Int32, Int32, T}) def initialize(n) @size = n.to_i64 @edges = [] of {Int32, Int32, T} @cmp = ->(a : T, b : T) { a <=> b } end def initialize(n, &@cmp : (T, T) -> Int32) @size = n.to_i64 @edges = [] of {Int32, Int32, T} end # $n$ 頂点 $0$ 辺のグラフを生成します。 # # 最小全域木を構築します。 def self.min(n) new(n) { |lhs, rhs| lhs <=> rhs } end # $n$ 頂点 $0$ 辺のグラフを生成します。 # # 最大全域木を構築します。 def self.max(n) new(n) { |lhs, rhs| rhs <=> lhs } end # グラフに辺 $(u, v, w)$ を追加します。 # # ``` # graph = MSTGraph(Int64).new(n) { |a, b| a < b } # m.times { graph.add_edge(u, v, w) } # ``` def add_edge(u : Int, v : Int, w : T) @edges << {u.to_i32, v.to_i32, w} end # 最小全域木を構成したときの辺の重みの総和求めます。 # # ``` # graph = MSTGraph(Int64).new(n) { |a, b| a < b } # m.times { graph.add_edge(u, v, w) } # graph.sum # ``` def sum @edges.sort! { |lhs, rhs| @cmp.call(lhs[2], rhs[2]) } ut = AtCoder::DSU.new(@size) res = T.zero @edges.each do |(u, v, w)| next if ut.same?(u, v) ut.merge(u, v) res += w end res end end end # require "../constants.cr" module NgLib # $n$ 頂点 $m$ 辺の重み付きグラフに対して、最短経路を求めます。 # # 経路の復元も可能です。 # # このクラスは辺の重みが非負整数であるときのみ使えます。 # 辺の重みに非負整数以外を使いたい場合は `nglib/graph/dijkstra` を `require` してください。 class DijkstraGraph record Edge, target : Int32, weight : UInt64 # 基数ヒープ private class RadixHeap64(T) @s : Int32 @last : UInt64 @bit : Int32 @vs : Array(Array({UInt64, T})) @ms : Array(UInt64) def initialize @s = 0 @last = 0_u64 @bit = sizeof(UInt64) * 8 @vs = Array.new(@bit + 1) { [] of {UInt64, T} } @ms = Array.new(@bit + 1) { -1.to_u64! } end def empty? : Bool @s == 0 end def size : Int32 s end @[AlwaysInline] def get_bit(x : UInt64) : UInt64 64_u64 - x.leading_zeros_count end def push(key : UInt64, val : T) : Nil @s += 1 b = get_bit(key ^ @last) @vs[b] << {key, val} @ms[b] = Math.min(@ms[b], key) end def pop : {UInt64, T} if @ms[0] == -1.to_u64! idx = @ms.index! { |elem| elem != -1.to_u64! } @last = @ms[idx] @vs[idx].each do |v| b = get_bit(v[0] ^ @last) @vs[b] << v @ms[b] = Math.min(@ms[b], v[0]) end @vs[idx].clear @ms[idx] = -1.to_u64! end @s -= 1 res = @vs[0].last @vs[0].pop @ms[0] = -1.to_u64! if @vs[0].empty? res end end getter size : Int32 @graph : Array(Array(Edge)) # $n$ 頂点 $0$ 辺からなるグラフを作成します。 # # ``` # graph = Dijkstra.new(n) # ``` def initialize(n : Int) @size = n.to_i32 @graph = Array.new(@size) { Array(Edge).new } end # 非負整数の重み $w$ の辺 $(u, v)$ を追加します。 # # `directed` が `true` の場合、 # 有向辺とみなして、$u$ から $v$ への辺のみ生やします。 # # ``` # graph = Dijkstra.new(n) # graph.add_edge(u, v, w) # => (u) <---w---> (v) # graph.add_edge(u, v, w, directed: true) # => (u) ----w---> (v) # ``` def add_edge(u : Int, v : Int, w : Int, directed : Bool = true) @graph[u.to_i32] << Edge.new(v.to_i32, w.to_u64) @graph[v.to_i32] << Edge.new(u.to_i32, w.to_u64) unless directed end # 全点対間の最短経路長を返します。 # # ``` # dists = graph.shortest_path # dists # => [[0, 1, 3], [1, 0, 2], [1, 1, 0]] # ``` def shortest_path (0...@size).map { |start| shortest_path(start) } end # 始点 `start` から各頂点への最短経路長を返します。 # # ``` # dist = graph.shortest_path(2) # dist # => [3, 8, 0, 7, 1] # ``` def shortest_path(start : Int) dist = [OO] * @size dist[start] = 0_i64 next_node = RadixHeap64(Int32).new next_node.push(0_u64, start.to_i32) until next_node.empty? d, source = next_node.pop next if dist[source] < d @graph[source].each do |e| next_cost = dist[source] + e.weight if next_cost < dist[e.target] dist[e.target] = next_cost next_node.push(next_cost.to_u64, e.target) end end end dist end # 始点 `start` から終点 `dest` への最短経路長を返します。 # # ``` # dist = graph.shortest_path(start: 2, dest: 0) # dist # => 12 # ``` def shortest_path(start : Int, dest : Int) shortest_path(start)[dest] end # 始点 `start` から終点 `dest` への最短経路の一例を返します。 # # ``` # route = graph.shortest_path_route(start: 2, dest: 0) # route # => [2, 7, 1, 0] # ``` def shortest_path_route(start, dest) prev = impl_memo_route(start) res = Array(Int32).new now : Int32? = dest.to_i32 until now.nil? res << now.not_nil! now = prev[now] end res.reverse end # 始点 `start` から最短路木を構築します。 # # 最短路木は `start` からの最短経路のみを残した全域木です。 # # ``` # route = graph.shortest_path_route(start: 2, dest: 0) # route # => [2, 7, 1, 0] # ``` def shortest_path_tree(start, directed : Bool = true) : Array(Array(Int32)) dist = [OO] * @size dist[start] = 0_i64 next_node = RadixHeap64(Int32).new next_node.push(0_u64, start.to_i32) birth = [-1] * @size until next_node.empty? d, source = next_node.pop next if dist[source] < d @graph[source].each do |e| next_cost = dist[source] + e.weight if next_cost < dist[e.target] dist[e.target] = next_cost next_node.push(next_cost.to_u64, e.target) birth[e.target] = source end end end tree = Array.new(@size) { [] of Int32 } @size.times do |target| source = birth[target] next if source == -1 tree[source] << target tree[target] << source unless directed end tree end # 経路復元のための「どこから移動してきたか」を # メモした配列を返します。 private def impl_memo_route(start) dist = [OO] * @size dist[start] = 0_i64 prev = Array(Int32?).new(@size) { nil } next_node = RadixHeap64(Int32).new next_node.push(0_u64, start.to_i32) until next_node.empty? d, source = next_node.pop next if dist[source] < d @graph[source].each do |e| next_cost = dist[source] + e.weight if next_cost < dist[e.target] dist[e.target] = next_cost prev[e.target] = source next_node.push(next_cost.to_u64, e.target) end end end prev end end end module NgLib class SCC alias Graph = Array(Array(Int64)) getter leader : Array(Int64) getter graph : Graph getter groups : Array(Array(Int64)) @n : Int64 @order : Array(Int64) @fwd : Graph @bwd : Graph @closed : Array(Bool) @cycles : Array(Array(Int64)) def initialize(@fwd : Graph) @n = @fwd.size.to_i64 @order = Array(Int64).new(@n) @leader = Array.new(@n, -1_i64) @bwd = Array.new(@n) { Array(Int64).new } @n.times do |i| @fwd[i].each do |j| @bwd[j] << i end end @closed = Array(Bool).new(@n, false) @n.times { |i| dfs(i) } @order = @order.reverse ptr = rdfs @graph = Array.new(ptr) { Array(Int64).new } @groups = Array.new(ptr) { Array(Int64).new } @n.times do |i| @groups[@leader[i]] << i @fwd[i].each do |j| x, y = @leader[i], @leader[j] next if x == y @graph[x] << y end end @cycles = Array(Array(Int64)).new end def same(u : Int, v : Int) leader[u] == leader[v] end def size @groups.size end def size(v : Int) @groups[leader[v]].size end # 各グループ `g` $\in$ `#groups` に対して、 # サイクルに現れる頂点をDFSの訪問順に並べたものを返します。 # # NOTE: 自己ループがある場合、サイズ 1 のサイクルが出現することに注意してください。 def cycles groups.each do |group| root = group[0] if group.size == 1 cycles << [root.to_i64] if @fwd[root].includes?(root) next end end end private def dfs(i : Int) return if @closed[i] @closed[i] = true @fwd[i].each { |j| dfs(j) } @order << i end private def rdfs ptr = 0_i64 closed = Array.new(@n, false) @order.each do |start| next if closed[start] que = Deque(Int64).new que << start closed[start] = true @leader[start] = ptr until que.empty? now = que.shift @bwd[now].each do |nxt| next if closed[nxt] closed[nxt] = true @leader[nxt] = ptr que << nxt end end ptr += 1 end ptr end end end module NgLib # 巡回セールスマン問題を解きます。 # # 内部では BitDP を用いているため、 # 頂点数が大きいグラフには対応できません。 # # 通常の巡回セールスマン問題を解きたい場合は、 # `#shortest_route(should_back : Bool = true)` を利用してください。 # # 始点を指定したいという特殊な場合は、 # `#shortest_route(start : Int, should_back : Bool = true)` を利用してください。 # # オリジナルの巡回セールスマン問題は各頂点に一度しか訪れることができません。 # 同じ頂点に複数回訪れられる場合は、`NgLib::FloydWarshall` などで、全点対最短経路長を求め、 # それを隣接行列として渡してください。 # # 計算量は $O(N^2 2^N)$ です。 class TSPGraph(T) getter size : Int32 getter mat : Array(Array(T?)) # $n$ 頂点 $0$ 辺のグラフを作ります。 # # ``` # n = 10 # NgLib::TSPGraph(Int64).new(n) # ``` def initialize(n : Int) @size = n.to_i32 @mat = Array.new(n) { Array.new(n) { nil.as(T?) } } @size.times do |i| @mat[i][i] = T.zero end end # 隣接行列に従ってグラフを作ります。 # # `nil` は辺が存在しないことを表します。 # 無限大の重みを持つ辺と捉えても良いです。 # # ``` # mat = [[0, 3, 1], [-2, 0, 4], [nil, nil, 0]] # NgLib::TSPGraph(Int32).new(mat) # ``` def initialize(@mat : Array(Array(T?))) @size = @mat.size @size.times do |i| @mat[i][i] = T.zero end end # :ditto: def initialize(matrix : Array(Array(T?) | Array(T))) @mat = matrix.map { |line| line.map { |v| v.as(T?) } } @size = @mat.size @size.times do |i| @mat[i][i] = T.zero end end # 重みが $w$ の辺 $(u, v)$ を追加します。 # # `directed` が `true` である場合、有向辺として追加します。 # # ``` # n, m = read_line.split.map &.to_i # graph = NgLib::TSPGraph.new(n) # m.times do # u, v, w = read_line.split.map &.to_i64 # u -= 1; v -= 1 # 0-index # graph.add_edge(u, v, w, directed: true) # end # ``` def add_edge(u : Int, v : Int, w : T, directed : Bool = true) uv = @mat[u][v] if uv.nil? @mat[u][v] = w else @mat[u][v] = {uv, w}.min end unless directed vu = @mat[v][u] if vu.nil? @mat[v][u] = w else @mat[v][u] = {vu, w}.min end end end # `dp[S][i] := 今まで訪問した頂点の集合が S で、最後に訪れた頂点が i であるときの最小経路長` を # 返します。 # # `should_back` が `true` なら、始点の頂点に戻ってこない場合の最小経路長を計算します。 # また、任意の始点に対しての答えを求める点に注意してください。 # # `should_back` が `false` なら通常の巡回セールスマン問題の答えです。 # 始点が頂点 $0$ であることに注意してください。 # つまり、`dp[(1 << n) - 1][0]` が答えです。 # # どのような順でも到達できない場合は `nil` が格納されます。 # # ``` # graph = TSPGraph(Int64).new(n) # dist = graph.shortest_route # dist[(1 << n) - 1][0] # => ans # dist.last.first # => ans # ``` def shortest_route(should_back : Bool = true) dp = Array.new(1 << @size) { Array.new(@size) { nil.as(T?) } } if should_back dp[0][0] = T.zero else @size.times do |i| dp[1 << i][i] = T.zero end end calc(dp) end # `dp[S][i] := 今まで訪問した頂点の集合が S で、最後に訪れた頂点が i であるときの最小経路長` を # 返します。 # # `should_back` が `true` なら、始点の頂点に戻ってこない場合の最小経路長を計算します。 # また、始点が `start` であることに注意してください。 # # `should_back` が `false` なら通常の巡回セールスマン問題の答えです。 # 始点が頂点 `start` であることに注意してください。 # つまり、`dp[(1 << n) - 1][start]` が答えです。 # # どのような順でも到達できない場合は `nil` が格納されます。 # # ``` # graph = TSPGraph(Int64).new(n) # dist = graph.shortest_route(start: 2) # dist[(1 << n) - 1][2] # ``` def shortest_route(start : Int, should_back : Bool = true) dp = Array.new(1 << @size) { Array.new(@size) { nil.as(T?) } } if should_back dp[0][start] = T.zero else dp[1 << start][start] = T.zero end calc(dp) end private def calc(dp : Array(Array(T?))) dist = @mat (1 << @size).times do |visited| @size.times do |dest| @size.times do |from| next if visited != 0 && visited.bit(from) == 0 next if visited.bit(dest) == 1 now = dp[visited][from] d = dist[from][dest] next if now.nil? next if d.nil? nxt = dp[visited | (1 << dest)][dest] if nxt.nil? || nxt > now + d dp[visited | (1 << dest)][dest] = now + d end end end end dp end end end module NgLib struct Edge @data : {Int64, Int64} def initialize(t : Int64, w : Int64) @data = {t, w} end def self.to @data[0] end def self.weight @data[1] end def [](i : Int) @data[i] end end end require "atcoder/priority_queue" struct Int def self.bar -1 end end struct Char def self.bar '#' end end module NgLib class Grid(T) class UnreachableError < Exception end include Enumerable(T) def self.add(v1 : {Int, Int}, v2 : {Int, Int}) {v1[0] + v2[0], v1[1] + v2[1]} end def self.sub(v1 : {Int, Int}, v2 : {Int, Int}) {v1[0] - v2[0], v1[1] - v2[1]} end UP = {-1, 0} LEFT = {0, -1} DOWN = {1, 0} RIGHT = {0, 1} DYDX2 = [DOWN, RIGHT] DYDX4 = [UP, LEFT, DOWN, RIGHT] DYDX8 = [ UP, add(UP, RIGHT), RIGHT, add(DOWN, RIGHT), DOWN, add(DOWN, LEFT), LEFT, add(UP, LEFT), ] alias Pos = {Int32, Int32} getter h : Int32, w : Int32 getter delta : Array(Pos) @s : Array(T) @bar : T def self.dydx2(s : Array(Array(T))) new(s, DYDX2) end def self.dydx2(height : Int, width : Int) new(height, width, DYDX2) end def self.dydx2(height : Int, &) new(height, DYDX2) { |line| yield line } end def self.dydx4(s : Array(Array(T))) new(s, DYDX4) end def self.dydx4(height : Int, width : Int, &) new(height, width, DYDX4) { |i, j| yield i, j } end def self.dydx4(height : Int, &) new(height, DYDX4) { |line| yield line } end def self.dydx8(s : Array(Array(T))) new(s, DYDX8) end def self.dydx8(height : Int, width : Int, &) new(height, width, DYDX8) { |i, j| yield i, j } end def self.dydx8(height : Int, &) new(height, DYDX8) { |line| yield line } end def initialize(s : Array(Array(T)), @delta) @h = s.size @w = s[0].size @s = s.flatten @bar = T.bar end def initialize(h : Int, @delta, &) @h = h.to_i @w = -1 @s = Array(Array(T)).new(h) { |line| t = (yield line); @w = t.size; t }.flatten raise "@w is null" if @w == -1 @bar = T.bar end def initialize(h : Int, w : Int, @delta, &) @h = h.to_i @w = w.to_i @s = Array(T).new(h * w) { |x| yield x // w, x % w } @bar = T.bar end # 位置 `pos` に対して次の座標をタプルで返します。 # # ここで「次」とは、each_with_coord で走査するときの順と同様です。 # # 次が存在しない場合は `nil` を返します。 # # ``` # grid.h, grid.w # => 3, 4 # grid.next_coord({1, 2}) # => {1, 3} # grid.next_coord({1, 3}) # => {2, 0} # grid.next_coord({2, 3}) # => nil # ``` def next_coord(pos) j = (pos[1] + 1) % @w i = pos[0] + (j == 0 ? 1 : 0) i >= @h ? nil : {i, j} end # 位置 `pos` に対して次の座標をタプルで返します。 # # ここで「次」とは、each_with_coord で走査するときの順と同様です。 # # 次が存在しない場合はエラーを送出します。 # # ``` # grid.h, grid.w # => 3, 4 # grid.next_coord({1, 2}) # => {1, 3} # grid.next_coord({1, 3}) # => {2, 0} # grid.next_coord({2, 3}) # => nil # ``` def next_coord!(pos) next_coord(pos) || raise Exception.new end # 位置 `pos` がグリッドの範囲外なら `true` を返します。 # # ``` # grid.over?({-1, 0}) # => true # grid.over?({h + 10, w + 10}) # => true # grid.over?({0, 0}) # => false # ``` @[AlwaysInline] def over?(pos) : Bool over?(pos[0], pos[1]) end # 位置 $(y, x)$ がグリッドの範囲外なら `true` を返します。 # # ``` # grid.over?(-1, 0) # => true # grid.over?(h + 10, w + 10) # => true # grid.over?(0, 0) # => false # ``` @[AlwaysInline] def over?(y, x) : Bool y < 0 || y >= @h || x < 0 || x >= @w end # 位置 `pos` が進入禁止なら `true` を返します。 # # ``` # s = [ # "..".chars, # ".#".chars, # ] # # grid.barred?({0, 0}) # => false # grid.barred?({1, 1}) # => true # ``` @[AlwaysInline] def barred?(pos) : Bool barred?(pos[0], pos[1]) end # 位置 $(y, x)$ が進入禁止なら `true` を返します。 # # ``` # s = [ # "..".chars, # ".#".chars, # ] # # grid.barred?(0, 0) # => false # grid.barred?(1, 1) # => true # ``` @[AlwaysInline] def barred?(y : Int, x : Int) : Bool over?(y, x) || self[y, x] == @bar end # 位置 `pos` が通行可能なら `true` を返します。 # # ``` # s = [ # "..".chars, # ".#".chars, # ] # # grid.free?({0, 0}) # => true # grid.free?({1, 1}) # => false # ``` @[AlwaysInline] def free?(pos) : Bool !barred?(pos) end # 位置 $(y, x)$ が通行可能なら `true` を返します。 # # ``` # s = [ # "..".chars, # ".#".chars, # ] # # grid.free?(0, 0) # => true # grid.free?(1, 1) # => false # ``` @[AlwaysInline] def free?(y : Int, x : Int) : Bool !barred?(y, x) end def simulate(si : Int, sj : Int, directions : Enumerable, iterations : Enumerable) : {Int32, Int32} lwalls = self.line_walls cwalls = self.column_walls now_i, now_j = si.to_i, sj.to_i directions.zip(iterations) do |dir, iter| case dir when 'L' walls = lwalls[now_i] pos = (walls.bsearch_index { |x| x >= now_j } || walls.size) - 1 next_j = walls[pos] + 1 now_j = {now_j - iter, next_j}.max when 'R' walls = lwalls[now_i] pos = walls.bsearch_index { |x| x > now_j } next_j = pos ? walls[pos] - 1 : @w - 1 now_j = {now_j + iter, next_j}.min when 'U' walls = cwalls[now_j] pos = (walls.bsearch_index { |x| x >= now_i } || walls.size) - 1 next_i = (pos >= 0 ? walls[pos] : -1) + 1 now_i = {now_i - iter, next_i}.max when 'D' walls = cwalls[now_j] pos = walls.bsearch_index { |x| x > now_i } next_i = pos ? walls[pos] - 1 : @h - 1 now_i = {now_i + iter, next_i}.min end end {now_i, now_j} end def simulate(si : Int, sj : Int, directions : Enumerable) : {Int32, Int32} simulate(si, sj, directions, [1] * directions.size) end def line_walls : Array(Array(Int32)) walls = Array.new(@h) { [] of Int32 } @h.times do |i| walls[i] << -1 @w.times do |j| walls[i] << j if barred?(i, j) end walls[i] << @w end walls end def column_walls : Array(Array(Int32)) walls = Array.new(@w) { [] of Int32 } @w.times do |j| walls[j] << -1 @h.times do |i| walls[j] << i if barred?(i, j) end walls[j] << @h end walls end # 全マス間の最短経路長を返します。 # # 到達できない場合は `nil` が格納されます。 # # ``` # dist = grid.shortest_path # dist[si][sj][gi][gj] # => 4 # ``` def shortest_path : Array(Array(Array(Array(Int64?)))) Array.new(@h) { |start_i| Array.new(@w) { |start_j| shortest_path(start_i, start_j) } } end # 始点 $(s_i, s_j)$ から各マスへの最短経路長を返します。 # # 到達できない場合は `nil` が格納されます。 # # ``` # dist = grid.shortest_path(start: {si, sj}) # dist[gi][gj] # => 4 # ``` def shortest_path(start : Tuple) : Array(Array(Int64?)) queue = Deque.new([start]) dist = Array.new(@h) { Array.new(@w) { nil.as(Int64?) } } dist[start[0]][start[1]] = 0 until queue.empty? i, j = queue.shift d = dist[i][j] || raise NilAssertionError.new each_neighbor(i, j) do |i_adj, j_adj| next unless dist[i_adj][j_adj].nil? dist[i_adj][j_adj] = d + 1 queue << {i_adj, j_adj} end end dist end # :ditto: def shortest_path(start_i : Int, start_j : Int) : Array(Array(Int64?)) shortest_path({start_i, start_j}) end # 始点 $(s_i, s_j)$ から終点 $(g_i, g_j)$ への最短経路長を返します。 # # ``` # grid.shortest_path(start: {si, sj}, dest: {gi, gj}) # => 4 # ``` def shortest_path(start : Tuple, dest : Tuple) : Int64? shortest_path(start)[dest[0]][dest[1]] end # 始点 $(s_i, s_j)$ から終点 $(g_i, g_j)$ への最短経路長を返します。 # # ``` # grid.shortest_path(start: {si, sj}, dest: {gi, gj}) # => 4 # ``` def shortest_path!(start : Tuple, dest : Tuple) : Int64 shortest_path(start)[dest[0]][dest[1]] || raise UnreachableError.new end # 全マス間の最短経路長を返します。 # # 内部で利用するアルゴリズムをタグで指定します。 # - `:bfs` : 侵入不可能な場合は U::MAX を返してください。 # - `:binary_bfs` : 重みは $0$ または $1$ である必要があります。 # - `:dijkstra` : デフォルト値です。$\infty = U::MAX$ です。負の数には気をつけてください。 # # $(i, j)$ から $(i', j')$ への移動時の重みをブロックで指定します。 # # ``` # dist = grid.shortest_path { |i, j, i_adj, j_adj| f(i, j, i_adj, j_adj) } # dist[si][sj][gi][gj] # => 4 # ``` def shortest_path(tag = :dijkstra, & : Int32, Int32, Int32, Int32 -> U) : Array(Array(Int64)) forall U Array.new(@h) { |start_i| Array.new(@w) { |start_j| shortest_path(start_i, start_j) { |i, j, i_adj, j_adj| yield i, j, i_adj, j_adj } } } end # 始点 $(s_i, s_j)$ から各マスへの最短経路長を返します。 # # 内部で利用するアルゴリズムをタグで指定します。 # - `:bfs` : 侵入不可能な場合は U::MAX を返してください。 # - `:binary_bfs` : 重みは $0$ または $1$ である必要があります。 # - `:dijkstra` : デフォルト値です。$\infty = U::MAX$ です。負の数には気をつけてください。 # # $(i, j)$ から $(i', j')$ への移動時の重みをブロックで指定します。 # # ``` # dist = grid.shortest_path(start: {0, 0}) { |i, j, i_adj, j_adj| f(i, j, i_adj, j_adj) } # dist[gi][gj] # => 4 # ``` # ameba:disable Metrics/CyclomaticComplexity def shortest_path(start : Tuple, tag = :dijkstra, & : Int32, Int32, Int32, Int32 -> U) : Array(Array(U)) forall U case tag when :bfs next_node = Deque({Int32, Int32}).new([start]) dist = Array.new(@h) { Array.new(@w) { U::MAX } } dist[start[0]][start[1]] = U.zero until next_node.empty? i, j = next_node.shift each_neighbor(i, j) do |i_adj, j_adj| weight = yield i.to_i32, j.to_i32, i_adj.to_i32, j_adj.to_i32 raise "Weight error" unless weight == U.zero.succ || weight == U::MAX next if weight == U::MAX next if dist[i_adj][j_adj] != U::MAX dist[i_adj][j_adj] = dist[i][j] + U.zero.succ next_node << {i_adj, j_adj} end end return dist when :binary_bfs next_node = Deque({Int32, Int32}).new([start]) dist = Array.new(@h) { Array.new(@w) { U::MAX } } dist[start[0]][start[1]] = U.zero until next_node.empty? i, j = next_node.shift each_neighbor(i, j) do |i_adj, j_adj| weight = yield i.to_i32, j.to_i32, i_adj.to_i32, j_adj.to_i32 raise "Weight error" unless weight.in?({U.zero, U.zero.succ}) next_cost = dist[i][j] <= U::MAX - weight ? dist[i][j] + weight : U::MAX if next_cost < dist[i_adj][j_adj] dist[i_adj][j_adj] = next_cost if weight == 0 next_node.unshift({i_adj.to_i32, j_adj.to_i32}) else next_node << {i_adj.to_i32, j_adj.to_i32} end end end end return dist when :dijkstra next_node = AtCoder::PriorityQueue.new([{U.zero, start}]) dist = Array.new(@h) { Array.new(@w) { U::MAX } } dist[start[0]][start[1]] = U.zero until next_node.empty? d, pos = next_node.pop.not_nil! i, j = pos next if dist[i][j] < d each_neighbor(i, j) do |i_adj, j_adj| weight = yield i.to_i32, j.to_i32, i_adj.to_i32, j_adj.to_i32 next_cost = dist[i][j] <= U::MAX - weight ? dist[i][j] + weight : U::MAX if next_cost < dist[i_adj][j_adj] dist[i_adj][j_adj] = next_cost next_node << {next_cost, {i_adj.to_i32, j_adj.to_i32}} end end end return dist end raise "Tag Error" end # 始点 $(s_i, s_j)$ から終点 $(g_i, g_j)$ への最短経路長を返します。 # # 内部で利用するアルゴリズムをタグで指定します。 # - `:bfs` : 侵入不可能な場合は U::MAX を返してください。 # - `:binary_bfs` : 重みは $0$ または $1$ である必要があります。 # - `:dijkstra` : デフォルト値です。$\infty = U::MAX$ です。負の数には気をつけてください。 # # $(i, j)$ から $(i', j')$ への移動時の重みをブロックで指定します。 # # ``` # grid.shortest_path(start: {si, sj}, dest: {gi, gj}) { |i, j, i_adj, j_adj| # f(i, j, i_adj, j_adj) # } # => 4 # ``` def shortest_path(start : Tuple, dest : Tuple, tag = :dijkstra, & : Int32, Int32, Int32, Int32 -> U) : Int64 forall U shortest_path(start, tag) { |i, j, i_adj, j_adj| yield i, j, i_adj, j_adj }[dest[0]][dest[1]] end # グリッドを隣接リスト形式で無向グラフに変換します。 # # あるマス $(i, j)$ の頂点番号は $Wi + j$ となります。 # # - `:connect_free` : free 同士を結びます(デフォルト) # - `:connect_bar` : bar 同士を結びます # - `:connect_same_type` : bar 同士、free 同士を結びます # # ``` # s = [ # "..#".chars, # ".#.".chars, # "##.".chars, # ] # grid = Grid(Char).dydx4(s) # grid.to_graph # => [[3, 1], [0], [], [0], [], [8], [], [], [5]] # ``` def to_graph(type = :connect_free) : Array(Array(Int32)) graph = Array.new(@w * @h) { [] of Int32 } @h.times do |i| @w.times do |j| node = @w * i + j @delta.each do |(di, dj)| i_adj = i + di j_adj = j + dj next if over?(i_adj, j_adj) node2 = @w * i_adj + j_adj both_frees = free?(i, j) & free?(i_adj, j_adj) both_bars = barred?(i, j) & barred?(i_adj, j_adj) case type when :connect_free graph[node] << node2 if both_frees when :connect_bar graph[node] << node2 if both_bars when :connect_same_type graph[node] << node2 if both_frees || both_bars end end end end graph end # 連結する free および bar を塗り分けたグリッドを返します。 # free のマスは非負整数の連番でラベル付けされ、bar は負の連番でラベル付けされます。 # `label_grid.max` は `(島の数 - 1)` を返すことに注意してください。 # # ``` # s = [ # "..#".chars, # ".#.".chars, # "##.".chars, # ] # grid = Grid(Char).dydx4(s) # grid.label_grid # => [[0, 0, -1], [0, -2, 1], [-2, -2, 1]] # ``` def label_grid table = Array.new(@h) { [nil.as(Int32?)] * @w } free_index, bar_index = 0, -1 @h.times do |i| @w.times do |j| next unless table[i][j].nil? label = 0 is_bar = barred?(i, j) if is_bar label = bar_index bar_index -= 1 else label = free_index free_index += 1 end queue = Deque({Int32, Int32}).new([{i, j}]) table[i][j] = label until queue.empty? y, x = queue.shift @delta.each do |(dy, dx)| ny = y + dy nx = x + dx next if over?(ny, nx) next unless table[ny][nx].nil? next if is_bar ^ barred?(ny, nx) table[ny][nx] = label queue << {ny, nx} end end end end Grid(Int32).new(table.map { |line| line.map(&.not_nil!) }, @delta) end # グリッドの値を $(0, 0)$ から $(H, W)$ まで順に列挙します。 # # ``` # s = [ # "..#".chars, # ".#.".chars, # "##.".chars, # ] # grid = Grid(Char).dydx4(s) # gird.each { |c| puts c } # => '.', '.', '#', '.', ..., '.' # ``` def each(& : T ->) i = 0 while i < h j = 0 while j < w yield self[i, j] j += 1 end i += 1 end end # グリッドの値を $(0, 0)$ から $(H, W)$ まで順に列挙します。 # # index は $Wi + j$ を返します。通常は `each_with_coord` を利用することを推奨します。 def each_with_index(&) i = 0 while i < h j = 0 while j < w yield self[i, j], w*i + j j += 1 end i += 1 end end # グリッドの値を $(0, 0)$ から $(H, W)$ まで順に座標付きで列挙します。 # # ``` # s = [ # "..#".chars, # ".#.".chars, # "##.".chars, # ] # grid = Grid(Char).new(s) # gird.each { |c, (i, j)| puts c, {i, j} } # ``` def each_with_coord(&) i = 0 while i < h j = 0 while j < w yield self[i, j], {i, j} j += 1 end i += 1 end end # グリッドの各要素に対して、ブロックを実行した結果に変換したグリッドを返します。 def map(& : T -> U) : Grid(U) forall U ret = Array.new(h) { Array(U).new(w) } i = 0 while i < h j = 0 line = Array(U).new(w) while j < w line << yield self[i, j] j += 1 end ret[i] = line i += 1 end Grid(U).new(ret, @delta) end # グリッドの各要素に対して、ブロックを実行した結果に変換したグリッドを返します。 def map_with_coord(& : T, {Int32, Int32} -> U) : Grid(U) forall U ret = Array.new(h) { Array(U).new(w) } i = 0 while i < h j = 0 line = Array(U).new(w) while j < w line << yield self[i, j], {i, j} end ret[i] = line end Grid(U).new(ret, @delta) end def index(offset = {0, 0}, & : T ->) : {Int32, Int32}? i, j = offset while i < @h while j < @w return {i, j} if yield self[i, j] j += 1 end j = 0 i += 1 end nil end def index(obj, offset = {0, 0}) : {Int32, Int32}? index(offset) { |elem| elem == obj } end def index!(offset = {0, 0}, & : T ->) : {Int32, Int32} index(offset) { |elem| yield elem } || raise Exception.new("Not found.") end def index!(obj, offset = {0, 0}) : {Int32, Int32}? index!(offset) { |elem| elem == obj } end # 位置 $(y, x)$ の近傍で、侵入可能な位置を列挙します。 # # ``` # grid = Grid.dydx([".#.", "...", "..."]) # # grid.each_neighbor(1, 1) do |ny, nx| # end # ``` def each_neighbor(y : Int, x : Int, &) i = 0 while i < @delta.size ny = y + @delta[i][0] nx = x + @delta[i][1] yield ny, nx if free?(ny, nx) i += 1 end end # 位置 $(y, x)$ の近傍で、侵入可能な位置を方向とともに列挙します。 # # ``` # grid = Grid.dydx4([".#.", "...", "..."]) # # grid.each_neighbor(1, 1) do |ny, nx, dir| # end # ``` def each_neighbor_with_direction(y : Int, x : Int, &) i = 0 while i < @delta.size ny = y + @delta[i][0] nx = x + @delta[i][1] yield ny, nx, i if free?(ny, nx) i += 1 end end def node_index(y : Int, x : Int) y * @w + x end def fetch(y : Int, x : Int, default : T) over?(y, x) ? default : self[y, x] end def to_a : Array(Array(T)) a = Array.new(@h) { Array(T).new(@w) } @h.times do |i| @w.times do |j| a[i] << self[i, j] end end a end def to_s(io : IO) @h.times do |i| @w.times do |j| io << ' ' if j != 0 io << self[i, j] end io << '\n' end io end def [](pos : {Int, Int}) self[pos[0], pos[1]] end def [](y : Int, x : Int) @s[y*@w + x] end def []=(pos : {Int, Int}, c : T) self[pos[0], pos[1]] = c end def []=(y : Int, x : Int, c : T) @s[y*@w + x] = c end end end module NgLib # 不定方程式 $ax + by = c$ を解きます。 # # 常に解が存在するわけではないので、`#has_solution?` などで解が存在するかを確かめるようにしてください。 # # 多分、解の媒介変数 $m = 0$ のとき、$|x| + |y|$ が最小になります。 class LDE(T) class NotHasSolutionError < Exception; end @a : T @b : T @c : T @check : Bool @x0 : T @y0 : T @a2 : T @b2 : T @m : T # 不定方程式 $ax + by = c$ を作ります。 def initialize(a, b, c) @a, @b, @c = T.new(a), T.new(b), T.new(c) @m = T.zero x = [T.zero] y = [T.zero] @check = true g = @a.gcd(@b) if @c % g != 0 @x0 = @y0 = @a2 = @b2 = T.zero @check = false else extgcd(@a.abs, @b.abs, x, y) x[0] = -x[0] if @a < 0 y[0] = -y[0] if @b < 0 x[0] *= @c // g y[0] *= @c // g @x0 = x[0] @y0 = y[0] @a2 = -@a // g @b2 = @b // g end end # 現在の $m$ の値に対する $x$ の解を返します。 # # 解は媒介変数 $m$ を用いて $x = x_0 + mk,\ y = y_0 + mh$ と求まるので、 # この $x$ を返します。 # # 解が存在しない場合は `nil` を返します。 def x : T? @check ? @x0 : nil end # 現在の $m$ の値に対する $x$ の解を返します。 # # 解は媒介変数 $m$ を用いて $x = x_0 + mk,\ y = y_0 + mh$ と求まるので、 # この $x$ を返します。 # # 解が存在しない場合は例外を送出します。 def x! : T x || raise NotHasSolutionError.new end # 現在の $m$ の値に対する $y$ の解を返します。 # # 解は媒介変数 $m$ を用いて $x = x_0 + mk,\ y = y_0 + mh$ と求まるので、 # この $y$ を返します。 # # 解が存在しない場合は `nil` を返します。 def y : T? @check ? @y0 : nil end # 現在の $m$ の値に対する $y$ の解を返します。 # # 解は媒介変数 $m$ を用いて $x = x_0 + mk,\ y = y_0 + mh$ と求まるので、 # この $y$ を返します。 # # 解が存在しない場合は例外を送出します。 def y! : T y || raise NotHasSolutionError.new end # 解は媒介変数 $m$ を用いて $x = x_0 + mk,\ y = y_0 + mh$ と求まるので、 # この $k$ を返します。 # # 解が存在しない場合は `nil` を返します。 def k : T? @check ? @b2 : nil end # 現在の $m$ の値に対する $k$ の解を返します。 # # 解は媒介変数 $m$ を用いて $x = x_0 + mk,\ y = y_0 + mh$ と求まるので、 # この $k$ を返します。 # # 解が存在しない場合は例外を送出します。 def k! : T k || raise NotHasSolutionError.new end # 解は媒介変数 $m$ を用いて $x = x_0 + mk,\ y = y_0 + mh$ と求まるので、 # この $h$ を返します。 # # 解が存在しない場合は `nil` を返します。 def h : T? @check ? @a2 : nil end # 現在の $m$ の値に対する $h$ の解を返します。 # # 解は媒介変数 $m$ を用いて $x = x_0 + mk,\ y = y_0 + mh$ と求まるので、 # この $h$ を返します。 # # 解が存在しない場合は例外を送出します。 def h! : T h || raise NotHasSolutionError.new end # 現在の $m$ の値に対する解を返します。 # # 解は媒介変数 $m$ を用いて $x = x_0 + mk,\ y = y_0 + mh$ と求まるので、 # $(x_0, b', y_0, a')$ をこの順に格納したタプルとして返します。 # # 解が存在しない場合は `nil` を返します。 def solution : {T, T, T, T}? @check ? {@x0, @b2, @y0, @a2} : nil end # 現在の $m$ の値に対する解を返します。 # # 解が存在しない場合は例外を送出します。 def solution! : {T, T, T, T} solution || raise NotHasSolutionError.new end # 解が存在するかを返します。 def has_solution? @check end # 媒介変数 $m$ の値を返します。 def m @m end # 媒介変数 $m$ の値を更新します。 def m=(m) @x0 += (-(@m - m)) * @b2 @y0 += (-(@m - m)) * @a2 @m = T.new(m) end def to_s(io : IO) io << @a << "x" << (@b < 0 ? " - " : " + ") << @b.abs << "y = " << @c << " # => x = " << x << ", y = " << y end def inspect(io : IO) to_s(io) end private def extgcd(a, b, x, y) if b == 0 x[0], y[0] = T.new(1), T.new(0) return a end d = extgcd(b, a % b, y, x) y[0] -= (a // b) * x[0] d end end end module NgLib # 文字列 $S$ に対して、上手なハッシュを作ることで、比較やLCPを高速に求めます。 class RollingHash MOD = (1_u64 << 61) - 1 getter size : Int32 @base : UInt64 @power : Array(UInt64) @hash : Array(UInt64) @selfhash : UInt64 # 配列 $a$ に対する、基数が `base` のロリハを構築します。 # # `base` は指定しない場合、ランダムに生成されます。 # # ``` # rh = RollingHash.new([1, 2, 5, 1, 2]) # ``` def initialize(a : Array(Int), base : UInt64? = nil) initialize(a.size, a, base) end # 文字列 $s$ に対する、基数が `base` のロリハを構築します。 # # `base` は指定しない場合、ランダムに生成されます。 # # ``` # rh = RollingHash.new("missisippi") # ``` def initialize(s : String, base : UInt64? = nil) initialize(s.size, s.bytes, base) end # Enumerable な列 $a$ に対する、基数が base のロリハを構築します。 # # `base` は指定しない場合、ランダムに生成されます。 # # ``` # rh = RollingHash.new(5, [1, 2, 5, 1, 2]) # ``` def initialize(@size, a : Enumerable, base : UInt64? = nil) base = RollingHash.create_base if base.nil? @base = base.not_nil! @power = [1_u64] * (@size + 1) @hash = [0_u64] * (@size + 1) a.each_with_index do |x, i| @power[i + 1] = mul(@power[i], @base) @hash[i + 1] = mul(@hash[i], @base) + x.to_u64 @hash[i + 1] -= MOD if @hash[i + 1] >= MOD end @selfhash = hash(a) end # ランダムに基底を生成します。 # # ``` # base = RollingHash.create_base # base # => 1729 # ``` def self.create_base rand(628_u64..MOD - 2) end # 初期化時に使用した列に対するハッシュを返します。 # # ``` # rh = RollingHash.new("missisippi") # rh.hash # => 339225237399054811 # rh.hash("missisippi") # => 339225237399054811 # ``` def hash : UInt64 @selfhash end # 文字列 $s$ のハッシュを返します。 # # ``` # rh = RollingHash.new("missisippi") # rh.hash("is") # => 339225237399054811 # rh.hash("abc") # => 496222201481864933 # ``` def hash(s : String) hash(s.bytes) end # 列 $s$ のハッシュを返します。 # # ``` # rh = RollingHash.new("missisippi") # rh.hash("is") # => 339225237399054811 # rh.hash("abc") # => 496222201481864933 # ``` def hash(s : Enumerable) s.reduce(0_u64) { |acc, elem| mul(acc, @base) + elem.to_u64 } end # `s[start...start + length]` のハッシュを返します。 # # ``` # rh = RollingHash.new("missisippi") # rh.substr(4, length: 2) # => 339225237399054811 # rh.substr(5, length: 2) # => 339225237399054811 # ``` def substr(start : Int, length : Int) : UInt64 res = @hash[start + length] + MOD - mul(@hash[start], @power[length]) res < MOD ? res : res - MOD end # `range` で指定した範囲 `s[range]` のハッシュを返します。 # # ``` # rh = RollingHash.new("missisippi") # rh.slice(4..5) # => 339225237399054811 # rh.slice(5..6) # => 339225237399054811 # ``` def slice(range : Range(Int?, Int?)) : UInt64 left = (range.begin || 0) right = if range.end.nil? @size else range.end.not_nil! + (range.exclusive? ? 0 : 1) end length = right - left substr(start: left, length: length) end # `range` で指定した範囲 `s[range]` のハッシュを返します。 # # ``` # rh = RollingHash.new("missisippi") # rh[4..5] # => 339225237399054811 # rh[5..6] # => 339225237399054811 # ``` def [](range : Range(Int?, Int?)) : UInt64 slice(range) end # ハッシュ値 $h_1$ とハッシュ値 $h_2$ を結合したときのハッシュ値を返します。 # # ハッシュ値 $h_2$ の元々の長さを渡す必要があります。 # # ``` # rh = RollingHash.new("missisippi") # h1 = rh[1..2] # "is" # h2 = rh[5..6] # "si" # h = rh.concat(h1, h2, h2_len: 2) # h == rh.[1..4] # => true # ``` def concat(h1 : UInt64, h2 : UInt64, h2_len : Int) : UInt64 res = mul(h1, @power[h2_len]) + h2 res < MOD ? res : res - MOD end # `s[i...]` と `other[j...]` の最長共通接頭辞の長さを返します。 # # `other` はデフォルトで自分自身を渡しています。 # 自分自身以外を渡す場合は $(mod, base)$ が一致している必要があります。 # # ``` # rh1 = RollingHash.new("missisippi") # rh1 = rh1.lcp(3, 5) # => 2 # rh1 = rh1.lcp(0, 1) # => 0 # ``` def lcp(i : Int, j : Int, other = self) : Int32 length = Math.min(@hash.size - i, @hash.size - j) ok = length - (1..length).bsearch { |len| l = length - len self.substr(start: i, length: l) == other.substr(start: j, length: l) }.not_nil! ok.to_i32 end # `s[i...]` と `t[j...]` の最長共通接頭辞の長さを返します。 # # $i, j$ はデフォルトで $0$ を渡しています。 # # ``` # rh1 = RollingHash.new("missisippi", base: 628) # rh2 = RollingHash.new("thisisapen", base: 628) # RollingHash.lcp(rh1, rh2) # => 0 # RollingHash.lcp(rh1, rh2, 4, 2) # => 3 # ``` def self.lcp(rh1 : self, rh2 : self, i : Int = 0, j : Int = 0) : Int32 rh1.lcp(i, j, rh2) end # 文字列検索を行います。 # # `s[offset..]` から `t` と一致する初めての添字を返します。 # 添字は `s` が基準です。また、`offset` が加算された値が返ります。 # # 存在しない場合は `nil` を返します。 # # ``` # rh = RollingHash.new("missisippi", base: 628) # rh.index("is") # => 1 # rh.index("is", offset: 4) # => 4 # rh.index("mid") # => nil # rh.index("i") # => 1 # rh.index("pi") # => 8 # ``` def index(t : String, offset : Int = 0) : Int32? index(t.bytes, offset) end # 検索を行います。 # # `s[offset..]` から `t` と一致する初めての添字を返します。 # 添字は `s` が基準です。また、`offset` が加算された値が返ります。 # # 存在しない場合は `nil` を返します。 # # ``` # rh = RollingHash.new("missisippi", base: 628) # rh.index("is") # => 1 # rh.index("is", offset: 4) # => 4 # rh.index("mid") # => nil # rh.index("i") # => 1 # rh.index("pi") # => 8 # ``` def index(t : Enumerable, offset : Int = 0) : Int32? ths = hash(t) t_len = t.size res = (offset..@size - t.size).index { |i| ths == substr(i, t_len) } res ? res.not_nil! + offset : nil end # 文字列検索を行います。 # # `s[offset..]` から `t` と一致する初めての添字を返します。 # 添字は `s` が基準です。また、`offset` が加算された値が返ります。 # # 存在しない場合は例外を投げます。 # # ``` # rh = RollingHash.new("missisippi", base: 628) # rh.index!("is") # => 1 # rh.index!("is", offset: 4) # => 4 # rh.index!("mid") # => Enumerable::NotFoundError # rh.index!("i") # => 1 # rh.index!("pi") # => 8 # ``` def index!(t : String, offset : Int = 0) : Int32 index!(t.bytes, offset) end # 検索を行います。 # # `s[offset..]` から `t` と一致する初めての添字を返します。 # 添字は `s` が基準です。また、`offset` が加算された値が返ります。 # # 存在しない場合は例外を投げます。 # # ``` # rh = RollingHash.new("missisippi", base: 628) # rh.index!("is") # => 1 # rh.index!("is", offset: 4) # => 4 # rh.index!("mid") # => Enumerable::NotFoundError # rh.index!("i") # => 1 # rh.index!("pi") # => 8 # ``` def index!(t : Enumerable, offset : Int = 0) : Int32 ths = 0_u64 t.each { |elem| ths = mul(ths, @base) + elem.to_u64 } t_len = t.size (offset..@size - t.size).index! { |i| ths == substr(i, t_len) } + offset end @[AlwaysInline] private def mul(a : UInt64, b : UInt64) : UInt64 t = a.to_u128 * b t = (t >> 61) + (t & MOD) (t < MOD ? t : t - MOD).to_u64 end end end module NgLib module FastIn extend self lib LibC fun getchar = getchar_unlocked : Char end module Scanner extend self {% for int_t in Int::Primitive.union_types %} {% if Int::Signed.union_types.includes?(int_t) %} def read_{{ int_t.name.downcase[0..0] }}{{ int_t.name.downcase[3...int_t.name.size] }}(offset = 0) {% else %} def read_{{ int_t.name.downcase[0..0] }}{{ int_t.name.downcase[4...int_t.name.size] }}(offset = 0) {% end %} c = next_char res = {{ int_t }}.new(c.ord - '0'.ord) sgn = 1 case c when '-' res = {{ int_t }}.new(LibC.getchar.ord - '0'.ord) sgn = -1 when '+' res = {{ int_t }}.new(LibC.getchar.ord - '0'.ord) end until ascii_voidchar?(c = LibC.getchar) res = res*10 + (c.ord - '0'.ord) end res * sgn + offset end {% end %} def read_char : Char next_char end def read_word : String c = next_char s = [c] until ascii_voidchar?(c = LibC.getchar) s << c end s.join end private def next_char : Char c = '_' while ascii_voidchar?(c = LibC.getchar) end c end private def ascii_voidchar?(c) c.ascii_whitespace? || c.ord == -1 end end end end # require "../constants" macro chmax(a, b); ({{a}} < {{b}} && ({{a}} = {{b}})) end macro chmin(a, b); ({{a}} > {{b}} && ({{a}} = {{b}})) end