tookunn’s diary

主に競技プログラミング関係

AtCoderBeginnerContest 050 D

問題

abc050.contest.atcoder.jp

解説見て通しました。

考察

これ以下の記述は自分が公式解説放送見ながら書いたメモみたいなものです。

a+b = x(xは問題文でのv)
a \oplus b = y(yは問題文でのu)
以上の式がある。
a_i(aのiビット目)とb_i(bのiビット目)に注目して考えると、a_ib_iを入れ替えてもa,b,x,yは変わらず同じ値になる。

ただし、(a_i,b_i) = (1,0)or(0,1)の場合はx,yの値は変わらないがa,bの値が変わってしまう。これだと(x,y)の組に対して重複して(a,b)を考えてしまうので、(a_i,b_i) = (1,0)または(0,1)に固定してしまう。(公式解説放送では(a_i,b_i) = (1,0)に固定していました)

つまり、(a_i,b_i)の組は(0,0),(1,0),(1,1)の3パターンあるぞということ。

そして、この3パターンを考慮して桁DPで求めていく。

a = 2a'
b = 2b'

このa',b'a,bを1ビット右シフトしたものと同じ。

2a' + 2b' \le x \to a' + b' \le x/2
(2a')\oplus(2b') \le y \to a' \oplus b' \le y/2
2a' \ge 2b' \to a' \ge b'

dp(x,y) = dp(x/2,y) + dp((x-2)/2,y/2) + dp((x-1)/2,(y-1)/2)

パターン(0,0)である時の遷移先 = dp(x/2,y/2)
パターン(1,0)である時の遷移先 = dp((x-1)/2,(y-1)/2)
パターン(1,1)である時の遷移先 = dp((x-2)/2,y/2)

youtu.be

ソースコード

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.NoSuchElementException;
import java.util.Objects;

public class Main {
	static int MOD = (int)1e9 + 7;
	long N;
	HashMap<Key,Long> map;

	private class Key implements Comparable<Key>{
		long x,y;
		public Key(long x,long y){
			this.x = x;
			this.y = y;
		}

		public int compareTo(Key p){
			if(Long.compare(this.x,p.x) == 0){
				return Long.compare(this.y, p.y);
			}
			return Long.compare(this.x, p.x);
		}

		public boolean equals(Object o){

			if(this == o){
				return true;
			}

			if(o instanceof Key){

				Key another = (Key)o;

				return this.x == another.x && this.y == another.y;
			}
			return false;
		}

		public int hashCode(){
			return Objects.hash(Long.hashCode(this.x),Long.hashCode(this.y));
		}
	}

	public long dfs(long S,long X){

		if(S == 0)return 1;

		Key key = new Key(S,X);

		if(map.containsKey(key)){
			return map.get(key);
		}

		long ret = 0;
		ret += dfs(S >> 1,X >> 1) % MOD;
		ret %= MOD;

		if(S > 1){
			ret += dfs((S - 2) >> 1,X >> 1) % MOD;
			ret %= MOD;
		}

		ret += dfs((S - 1) >> 1,(X - 1) >> 1) % MOD;
		ret %= MOD;

		map.put(key, ret);
		return ret;
	}

	public void solve() {
		N = nextLong();
		map = new HashMap<Key,Long>();
		out.println(dfs(N,N));
	}

	public static void main(String[] args) {
		out.flush();
		new Main().solve();
		out.close();
	}

	/* Input */
	private static final InputStream in = System.in;
	private static final PrintWriter out = new PrintWriter(System.out);
	private final byte[] buffer = new byte[2048];
	private int p = 0;
	private int buflen = 0;

	private boolean hasNextByte() {
		if (p < buflen)
			return true;
		p = 0;
		try {
			buflen = in.read(buffer);
		} catch (IOException e) {
			e.printStackTrace();
		}
		if (buflen <= 0)
			return false;
		return true;
	}

	public boolean hasNext() {
		while (hasNextByte() && !isPrint(buffer[p])) {
			p++;
		}
		return hasNextByte();
	}

	private boolean isPrint(int ch) {
		if (ch >= '!' && ch <= '~')
			return true;
		return false;
	}

	private int nextByte() {
		if (!hasNextByte())
			return -1;
		return buffer[p++];
	}

	public String next() {
		if (!hasNext())
			throw new NoSuchElementException();
		StringBuilder sb = new StringBuilder();
		int b = -1;
		while (isPrint((b = nextByte()))) {
			sb.appendCodePoint(b);
		}
		return sb.toString();
	}

	public int nextInt() {
		return Integer.parseInt(next());
	}

	public long nextLong() {
		return Long.parseLong(next());
	}

	public double nextDouble() {
		return Double.parseDouble(next());
	}
}