tookunn’s diary

主に競技プログラミング関係

RUPC2016 Day1 A 秤 / Steelyard

考察

吊らされたおもりと対称的におもりを吊っていく。
具体的には-w_iに同じ重さのおもりを吊っていく。

ソースコード

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.NoSuchElementException;
 
public class Main {
    int L,N;
    int[] x,w;
    ArrayList<Integer>[] list;
    @SuppressWarnings("unchecked")
    public void solve() {
        L = nextInt();
        N = nextInt();
        x = new int[N];
        w = new int[N];
 
        list = new ArrayList[2 * L + 1];
        for(int i = 0;i < 2 * L + 1;i++){
            list[i] = new ArrayList<Integer>();
        }
        for(int i = 0;i < N;i++){
            x[i] = nextInt();
            w[i] = nextInt();
            list[-x[i] + L].add(w[i]);
        }
        out.println(N);
        for(int i = 0;i < 2 * L + 1;i++){
            for(int weight : list[i]){
                out.println((i-L) + " " + weight);
            }
        }
 
 
    }
    public static void main(String[] args) {
        out.flush();
        new Main().solve();
        out.close();
    }
 
    /* Input */
    private static final InputStream in = System.in;
    private static final PrintWriter out = new PrintWriter(System.out);
    private final byte[] buffer = new byte[2048];
    private int p = 0;
    private int buflen = 0;
 
    private boolean hasNextByte() {
        if (p < buflen)
            return true;
        p = 0;
        try {
            buflen = in.read(buffer);
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (buflen <= 0)
            return false;
        return true;
    }
 
    public boolean hasNext() {
        while (hasNextByte() && !isPrint(buffer[p])) {
            p++;
        }
        return hasNextByte();
    }
 
    private boolean isPrint(int ch) {
        if (ch >= '!' && ch <= '~')
            return true;
        return false;
    }
 
    private int nextByte() {
        if (!hasNextByte())
            return -1;
        return buffer[p++];
    }
 
    public String next() {
        if (!hasNext())
            throw new NoSuchElementException();
        StringBuilder sb = new StringBuilder();
        int b = -1;
        while (isPrint((b = nextByte()))) {
            sb.appendCodePoint(b);
        }
        return sb.toString();
    }
 
    public int nextInt() {
        return Integer.parseInt(next());
    }
 
    public long nextLong() {
        return Long.parseLong(next());
    }
 
    public double nextDouble() {
        return Double.parseDouble(next());
    }
}

AtCoderBeginnerContest 010 D 浮気予防

問題

abc010.contest.atcoder.jp

解説見て蟻本写経して通しました。

考察

当初、複数のp_iを含む連結成分を探して,橋を見つけてそれを消していけば良いのかなぐらいに考えてたけど、そこから何も分からなかったので解説を見て最大フローを求める問題と分かった。
フローを扱う問題を解いたことなかったので蟻本や,以下記事を参考にしました。

参考にさせていただいた記事:
even-eko.hatenablog.com

qiita.com

公式の解説スライドがとても分かりやすいので、それを見れば大体どういうことをすれば良いのかなんとなく分かった。

  • G個のp_iに対して,始点を0(高橋君のID),終点をN(N + 1個目の頂点を追加)とするためにG個のp_iそれぞれと終点Nとの間に辺を追加する
  • 通常の辺と逆辺の分、つまり1つの友人関係に対して2つの辺を追加していく(問題の設定上無向グラフなので)

あとはFord-Fulkerson法と呼ばれる最大フローを求めるアルゴリズムを使ってやると求まる。らしい

公式解説:
AtCoder Beginner Contest 010 解説

ソースコード

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.NoSuchElementException;

public class Main {
	static final int INF = (int)1e9 + 7;
	int N,G,E;
	ArrayList<Edge>[] graph;
	boolean[] used;

	class Edge{
		int to,cap,rev;
		public Edge(int to,int cap,int rev){
			this.to = to;
			this.cap = cap;
			this.rev = rev;
		}
	}

	public void addEdge(int from,int to,int cap){
		graph[from].add(new Edge(to,cap,graph[to].size()));
		graph[to].add(new Edge(from,0,graph[from].size()-1));
	}

	public int dfs(int v,int t,int f){

		if(v == t)return f;
		used[v] = true;

		for(int i = 0;i < graph[v].size();i++){
			Edge e = graph[v].get(i);

			if(!used[e.to] && e.cap > 0){
				int d = dfs(e.to,t,Math.min(f,e.cap));
				if(d > 0){
					e.cap -= d;
					graph[e.to].get(e.rev).cap += d;
					return d;
				}
			}
		}
		return 0;
	}

	public int maxFlow(int s,int t){

		int flow = 0;
		used = new boolean[N+1];
		for(;;){
			Arrays.fill(used,false);
			int f = dfs(s,t,INF);
			if(f == 0)return flow;
			flow += f;
		}

	}

	@SuppressWarnings("unchecked")
	public void solve() {
		N = nextInt();
		G = nextInt();
		E = nextInt();

		graph = new ArrayList[N + 1];
		for(int i = 0;i < N + 1;i++){
			graph[i] = new ArrayList<Edge>();
		}

		for(int i = 0;i < G;i++){
			int p = nextInt();
			addEdge(p,N,1);
		}

		for(int i = 0;i < E;i++){

			int a = nextInt();
			int b = nextInt();
			addEdge(a,b,1);
			addEdge(b,a,1);
		}

		out.println(maxFlow(0,N));
	}

	public static void main(String[] args) {
		out.flush();
		new Main().solve();
		out.close();
	}

	/* Input */
	private static final InputStream in = System.in;
	private static final PrintWriter out = new PrintWriter(System.out);
	private final byte[] buffer = new byte[2048];
	private int p = 0;
	private int buflen = 0;

	private boolean hasNextByte() {
		if (p < buflen)
			return true;
		p = 0;
		try {
			buflen = in.read(buffer);
		} catch (IOException e) {
			e.printStackTrace();
		}
		if (buflen <= 0)
			return false;
		return true;
	}

	public boolean hasNext() {
		while (hasNextByte() && !isPrint(buffer[p])) {
			p++;
		}
		return hasNextByte();
	}

	private boolean isPrint(int ch) {
		if (ch >= '!' && ch <= '~')
			return true;
		return false;
	}

	private int nextByte() {
		if (!hasNextByte())
			return -1;
		return buffer[p++];
	}

	public String next() {
		if (!hasNext())
			throw new NoSuchElementException();
		StringBuilder sb = new StringBuilder();
		int b = -1;
		while (isPrint((b = nextByte()))) {
			sb.appendCodePoint(b);
		}
		return sb.toString();
	}

	public int nextInt() {
		return Integer.parseInt(next());
	}

	public long nextLong() {
		return Long.parseLong(next());
	}

	public double nextDouble() {
		return Double.parseDouble(next());
	}
}

AtCoderBeginnerContest 021 D 多重ループ

問題

abc021.contest.atcoder.jp

過去問埋めで久々のABC D問題の自力AC

考察

1 \le a_1 \le a_2 \le ... \le a_k \le nという性質から,a_1,a_2,a_3,a_4...,a_k内の複数のa_iの間で重複した値が許されるということが分かる。

これは結局,1\sim nまでの範囲の数値からk個の数値を重複を許し,取り出すことと同じ。

例えば,1\sim100までの範囲の数値から4個の数値を重複を許し,取り出すことを考えると

(78,20,9,57)という風な整数の組を取り出すことができ,整数の組中の数値を昇順ソートすると(9,20,57,78)となり問題文の条件(1 \le a_1 \le a_2 \le a_3 \le a_4 \le n)を満たす。

(22,65,22,12)のように重複する値が入っていても(12,22,22,65)になり,これも条件を満たすことになる。


以上の事と整数の組(a_1,a_2,a_3,a_4,...,a_k)の個数と解が等しいというところから重複組み合わせ_nH_k = _{n+k-1}C_kを求めれば解が求まることが分かる。

ソースコード

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.NoSuchElementException;

public class Main {
	static final int MOD = (int)1e9 + 7;
	int N,K;

	public long modPow(long x,long y){

		if(y == 0){
			return 1;
		}
		else if(y == 1){
			return x;
		}
		else if(y % 2 == 0){
			long z = modPow(x,y / 2);
			return z * z % MOD;
		}else{
			return (modPow(x,y - 1) * x) % MOD;
		}
	}

	public long nCk(int n,int k)
	{
		long a = 1;
		for(int i = 0;i < k;i++)
		{
			a *=  (n - i);
			a %= MOD;
		}

		long b = 1;
		for(int i = k;i >= 2;i--){
			b *= i;
			b %= MOD;
		}
		return (a * modPow(b,MOD - 2) % MOD) % MOD;
	}

	public void solve() {
		N = nextInt();
		K = nextInt();

		out.println(nCk(N + K - 1,K));
	}

	public static void main(String[] args) {
		out.flush();
		new Main().solve();
		out.close();
	}

	/* Input */
	private static final InputStream in = System.in;
	private static final PrintWriter out = new PrintWriter(System.out);
	private final byte[] buffer = new byte[2048];
	private int p = 0;
	private int buflen = 0;

	private boolean hasNextByte() {
		if (p < buflen)
			return true;
		p = 0;
		try {
			buflen = in.read(buffer);
		} catch (IOException e) {
			e.printStackTrace();
		}
		if (buflen <= 0)
			return false;
		return true;
	}

	public boolean hasNext() {
		while (hasNextByte() && !isPrint(buffer[p])) {
			p++;
		}
		return hasNextByte();
	}

	private boolean isPrint(int ch) {
		if (ch >= '!' && ch <= '~')
			return true;
		return false;
	}

	private int nextByte() {
		if (!hasNextByte())
			return -1;
		return buffer[p++];
	}

	public String next() {
		if (!hasNext())
			throw new NoSuchElementException();
		StringBuilder sb = new StringBuilder();
		int b = -1;
		while (isPrint((b = nextByte()))) {
			sb.appendCodePoint(b);
		}
		return sb.toString();
	}

	public int nextInt() {
		return Integer.parseInt(next());
	}

	public long nextLong() {
		return Long.parseLong(next());
	}

	public double nextDouble() {
		return Double.parseDouble(next());
	}
}

第3回 ドワンゴからの挑戦状 予選 B ニコニコレベル

考察

文字列を左から順に見ていき,'2','5','?'の時で場合分けして状態を遷移する。

2の場合
  • i番目の文字が2である状態からi+1番目の文字が5である状態に遷移する
5の場合
  • i番目の文字が5である状態からi+1番目の文字が2である状態に遷移する
  • しかし,ニコニコ文字列は252525..のように2から始まる文字列なので,i-1番目の文字が2である必要がある
?の場合
  • 2,5の場合で行う遷移を両方行えばよい

dp[i][j] = i番目の文字がjであり,i番目の文字まで見た時の連続した最長のニコニコ文字列の長さ


書いてる途中で気付いたけど,dp[N][10]じゃなくてdp[N][2]で良さそうな気がする(2,5の2通りしか考慮しないため)

ソースコード

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.NoSuchElementException;

public class Main {
	String T;
	char[] ch;
	int N;
	int[][] dp;

	public void solve() {
		T = next();
		ch = T.toCharArray();
		N = T.length();
		dp = new int[N+1][10];
		//dp[i番目][i番目の文字]
		for(int i = 0;i < N;i++){
			if(ch[i] == '?'){
				if(dp[i][5]%2==1)dp[i + 1][2] = Math.max(dp[i+1][2],dp[i][5]+1);
				dp[i + 1][5] = Math.max(dp[i+1][5],dp[i][2]+1);
			}else if(ch[i] == '2'){
				dp[i + 1][5] = Math.max(dp[i+1][5],dp[i][2]+1);
			}else if(ch[i] == '5'){
				if(dp[i][5]%2==1)dp[i + 1][2] = Math.max(dp[i+1][2],dp[i][5]+1);
			}
		}

		int ans = 0;
		for(int i = 0;i < N+1;i++){
			for(int j = 0;j < 10;j++){
				ans = Math.max(ans,dp[i][j]);
			}
		}
		out.println(ans/2 * 2);
	}

	public static void main(String[] args) {
		out.flush();
		new Main().solve();
		out.close();
	}

	/* Input */
	private static final InputStream in = System.in;
	private static final PrintWriter out = new PrintWriter(System.out);
	private final byte[] buffer = new byte[2048];
	private int p = 0;
	private int buflen = 0;

	private boolean hasNextByte() {
		if (p < buflen)
			return true;
		p = 0;
		try {
			buflen = in.read(buffer);
		} catch (IOException e) {
			e.printStackTrace();
		}
		if (buflen <= 0)
			return false;
		return true;
	}

	public boolean hasNext() {
		while (hasNextByte() && !isPrint(buffer[p])) {
			p++;
		}
		return hasNextByte();
	}

	private boolean isPrint(int ch) {
		if (ch >= '!' && ch <= '~')
			return true;
		return false;
	}

	private int nextByte() {
		if (!hasNextByte())
			return -1;
		return buffer[p++];
	}

	public String next() {
		if (!hasNext())
			throw new NoSuchElementException();
		StringBuilder sb = new StringBuilder();
		int b = -1;
		while (isPrint((b = nextByte()))) {
			sb.appendCodePoint(b);
		}
		return sb.toString();
	}

	public int nextInt() {
		return Integer.parseInt(next());
	}

	public long nextLong() {
		return Long.parseLong(next());
	}

	public double nextDouble() {
		return Double.parseDouble(next());
	}
}

AtCoderRegularContest 065 F シャッフル / Shuffling

問題

arc065.contest.atcoder.jp

さすがF問題ということで解説見て,解説放送見て,他の方の提出コードを見て,時間をかけてやっと理解(完全ではない)出来ました。復習必須に感じたのでメモとして記事にします。

考察

まず、まとめられるl_i,r_iはまとめていく。

まとめられるl_i,r_iとは
  • i \le jであり,r_j \le r_iの場合,区間(l_i,r_i)区間(l_j,r_j)を完全に含んでいるのでj番目の操作は行わなくて良い。つまり行わなくて良いj番目の操作をi番目の操作としてまとめる。

問題の制約上,l_iは非減少していくので,文字列の左から順に一文字ずつ文字を決定していくことが出来る。

なので方針としては文字列を左から見ていき,l_i番目となる文字に到達した時,l_iからr_iまでの部分文字列内に含まれる1の数(またはl_i番目の文字まで見た時に配置した1の数)を保持しておき,i番目に0を配置するか,1を配置するかで分岐していく。

そしてこれをDPとして解いていく。

参照:
解説放送・公式解説
www.youtube.com

Editorial - AtCoder Regular Contest 065 | AtCoder

ソースコード

  • dp[i][j] = i番目の文字まで見て、これまでj個の1を配置済みの時の組み合わせ
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.NoSuchElementException;

public class Main {
	static final int MOD = (int)1e9 + 7;
	int N,M;
	char[] ch;
	int[] r,l,newR,sum;
	long[][] dp;

	public void solve() {
		N = nextInt();
		M = nextInt();
		ch = next().toCharArray();

		l = new int[M];
		r = new int[M];
		for(int i = 0;i < M;i++){
			l[i] = nextInt() - 1;
			r[i] = nextInt() - 1;
		}

		newR = new int[N];
		for(int i = 0;i < N;i++){
			newR[i] = i;
		}
		for(int i = 0;i < M;i++){
			newR[l[i]] = Math.max(newR[l[i]],r[i]);
		}

		for(int i = 1;i < N;i++){
			newR[i] = Math.max(newR[i], newR[i-1]);
		}

		sum = new int[N];
		for(int i = 0;i < N;i++){
			sum[i] = ch[i] - '0';
		}
		for(int i = 0;i < N - 1;i++){
			sum[i + 1] += sum[i];
		}

		dp = new long[N + 1][N + 1];
		dp[0][0] = 1L;

		//i = 今見ているSの位置
		for(int i = 0;i < N;i++){

			//j = これまで配置した1の数
			for(int j = 0;j <= i;j++){

				if(dp[i][j] == 0)continue;

				/*
				 * j <= sum[newR[i]] jがnewR[i]]の位置までに存在する1の数以下(これを超えていたら存在する1の数より多いのでありえない)
				 * sum[newR[i]] - j <= newR[i] - i 残りの配置しなければならない1の数 <= newR[i]の位置まで配置できる0or1の数(これも超えていたら1が余分に存在している)
				 *
				 * 以上の条件を満たした時、dp[i + 1][j]に遷移する(iの位置に0を配置して次の状態に移る)
				 */
				if(j <= sum[newR[i]] && sum[newR[i]] - j <= newR[i] - i){
					dp[i + 1][j] += dp[i][j]%MOD;
					dp[i + 1][j] %= MOD;
				}

				/*
				 * j <= sum[newR[i]] - 1 iの位置に1を配置する分をsum[newR[i]]から引いてもちゃんとj以上になる(これがj未満になると矛盾していることになる)
				 * sum[newR[i]] - j - 1 <= newR[i] - i 今iの位置に配置する分(-1)とこれまで配置した分(j)をsum[newR[i]]から引いた残りの数(これはnewR[i]の位置までに配置しなければならない1の数)が
				 * 残りの配置できる位置の数以下(これが上回っていると配置することが出来ない1が出てきてしまう)
				 * 
				 * 以上の条件を満たした時、dp[i + 1][j + 1]に遷移する(iの位置に1を配置して次の状態に移る)
				 */
				if(j  + 1<= sum[newR[i]] && sum[newR[i]] - j - 1 <= newR[i] - i){
					dp[i + 1][j + 1] += dp[i][j]%MOD;
					dp[i + 1][j + 1] %= MOD;
				}
			}
		}

		out.println(dp[N][sum[N-1]]);
	}

	public static void main(String[] args) {
		out.flush();
		new Main().solve();
		out.close();
	}

	/* Input */
	private static final InputStream in = System.in;
	private static final PrintWriter out = new PrintWriter(System.out);
	private final byte[] buffer = new byte[2048];
	private int p = 0;
	private int buflen = 0;

	private boolean hasNextByte() {
		if (p < buflen)
			return true;
		p = 0;
		try {
			buflen = in.read(buffer);
		} catch (IOException e) {
			e.printStackTrace();
		}
		if (buflen <= 0)
			return false;
		return true;
	}

	public boolean hasNext() {
		while (hasNextByte() && !isPrint(buffer[p])) {
			p++;
		}
		return hasNextByte();
	}

	private boolean isPrint(int ch) {
		if (ch >= '!' && ch <= '~')
			return true;
		return false;
	}

	private int nextByte() {
		if (!hasNextByte())
			return -1;
		return buffer[p++];
	}

	public String next() {
		if (!hasNext())
			throw new NoSuchElementException();
		StringBuilder sb = new StringBuilder();
		int b = -1;
		while (isPrint((b = nextByte()))) {
			sb.appendCodePoint(b);
		}
		return sb.toString();
	}

	public int nextInt() {
		return Integer.parseInt(next());
	}

	public long nextLong() {
		return Long.parseLong(next());
	}

	public double nextDouble() {
		return Double.parseDouble(next());
	}
}
  • dp[i][j] = i番目までの文字を見た時,iに対応するr_i(どこまでシャッフルするかの最終地点)までに含まれる1の個数
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.NoSuchElementException;

public class Main {
	static final int MOD = (int)1e9 + 7;
	int N,M;
	char[] ch;
	int[] r,l,newR,sum;
	long[][] dp;


	public long dfs(int L,int R,int C){

		if(L == N){
			return 1;
		}

		if(dp[L][C] != -1)return dp[L][C];

		long ret = 0;
		int add = 0;

		if(R < newR[L+1]){
			add = sum[newR[L+1]] - sum[R];
		}

		//1を配置
		if(C > 0){
			ret += dfs(L + 1,newR[L+1],C - 1 + add) % MOD;
		}

		//0を配置
		if(C < R - L  + 1){
			ret += dfs(L + 1,newR[L+1],C + add) % MOD;
		}
		ret %= MOD;
		return dp[L][C] = ret;

	}

	public void solve() {
		N = nextInt();
		M = nextInt();
		ch = next().toCharArray();

		l = new int[M];
		r = new int[M];
		for(int i = 0;i < M;i++){
			l[i] = nextInt() - 1;
			r[i] = nextInt() - 1;
		}

		newR = new int[N + 1];
		for(int i = 0;i < N;i++){
			newR[i] = i;
		}
		for(int i = 0;i < M;i++){
			newR[l[i]] = Math.max(newR[l[i]],r[i]);
		}
		for(int i = 1;i < N;i++){
			newR[i] = Math.max(newR[i], newR[i-1]);
		}

		sum = new int[N];
		for(int i = 0;i < N;i++){
			sum[i] = ch[i] - '0';
		}
		for(int i = 0;i < N - 1;i++){
			sum[i + 1] += sum[i];
		}

		dp = new long[N + 1][N + 1];
		for(int i = 0;i < N + 1;i++){
			Arrays.fill(dp[i], -1);
		}

		out.println(dfs(0,newR[0],sum[newR[0]]));
	}

	public static void main(String[] args) {
		out.flush();
		new Main().solve();
		out.close();
	}
}

AtCoderGrandContest 005 B

問題文

agc005.contest.atcoder.jp

解説見ました。

考察

愚直にN^2個の区間ごとに最小値を計算しても計算量がO(N^2)になるのでTLEになる。

ここで最小値になる値に注目して区間を考える。

具体的な例として問題文の入力例2で考える。

i0123
a1324
a_i = 1の値を見ると最小値が1になる範囲はleft(iより左に存在する要素の数+1)*right(iより右に存在する要素の数+1)の数だけ存在する。 これを1~Nまでの値に対して行っていけば解が求まる。 つまり,1~Nまでの値a_iを見ていった時に、

  • iより小さく、a_iより小さい値の最大のiを持つインデックス
  • iより大きく、a_iより小さい値の最小のiを持つインデックス

が分かれば良い。

ソースコード

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.NoSuchElementException;
import java.util.TreeSet;

public class Main {
	int N;
	int[] a,b;
	TreeSet<Integer> set;
	public void solve() {
		N = nextInt();
		a = new int[N];
		b = new int[N + 1];
		for(int i = 0;i < N;i++){
			a[i] = nextInt();
			b[a[i]] = i;
		}

		set = new TreeSet<Integer>();

		long ans = 0;

		for(int i = 1;i <= N;i++){
			set.add(b[i]);
			int index = b[i];

			//indexより大きく、最小のインデックスを持つiより小さい要素のインデックス
			Integer right = set.higher(index);
			//indexより小さく、最大のインデックスを持つiより小さい要素のインデックス
			Integer left = set.lower(index);


			if(right == null){
				right = N;
			}
			right = right - index;

			if(left == null){
				left = -1;
			}
			left = index - left;

			ans += (long)right * left * i;
		}

		out.println(ans);

	}

	public static void main(String[] args) {
		out.flush();
		new Main().solve();
		out.close();
	}

	/* Input */
	private static final InputStream in = System.in;
	private static final PrintWriter out = new PrintWriter(System.out);
	private final byte[] buffer = new byte[2048];
	private int p = 0;
	private int buflen = 0;

	private boolean hasNextByte() {
		if (p < buflen)
			return true;
		p = 0;
		try {
			buflen = in.read(buffer);
		} catch (IOException e) {
			e.printStackTrace();
		}
		if (buflen <= 0)
			return false;
		return true;
	}

	public boolean hasNext() {
		while (hasNextByte() && !isPrint(buffer[p])) {
			p++;
		}
		return hasNextByte();
	}

	private boolean isPrint(int ch) {
		if (ch >= '!' && ch <= '~')
			return true;
		return false;
	}

	private int nextByte() {
		if (!hasNextByte())
			return -1;
		return buffer[p++];
	}

	public String next() {
		if (!hasNext())
			throw new NoSuchElementException();
		StringBuilder sb = new StringBuilder();
		int b = -1;
		while (isPrint((b = nextByte()))) {
			sb.appendCodePoint(b);
		}
		return sb.toString();
	}

	public int nextInt() {
		return Integer.parseInt(next());
	}

	public long nextLong() {
		return Long.parseLong(next());
	}

	public double nextDouble() {
		return Double.parseDouble(next());
	}
}

AtCoderBeginnerContest 050 D

問題

abc050.contest.atcoder.jp

解説見て通しました。

考察

これ以下の記述は自分が公式解説放送見ながら書いたメモみたいなものです。

a+b = x(xは問題文でのv)
a \oplus b = y(yは問題文でのu)
以上の式がある。
a_i(aのiビット目)とb_i(bのiビット目)に注目して考えると、a_ib_iを入れ替えてもa,b,x,yは変わらず同じ値になる。

ただし、(a_i,b_i) = (1,0)or(0,1)の場合はx,yの値は変わらないがa,bの値が変わってしまう。これだと(x,y)の組に対して重複して(a,b)を考えてしまうので、(a_i,b_i) = (1,0)または(0,1)に固定してしまう。(公式解説放送では(a_i,b_i) = (1,0)に固定していました)

つまり、(a_i,b_i)の組は(0,0),(1,0),(1,1)の3パターンあるぞということ。

そして、この3パターンを考慮して桁DPで求めていく。

a = 2a'
b = 2b'

このa',b'a,bを1ビット右シフトしたものと同じ。

2a' + 2b' \le x \to a' + b' \le x/2
(2a')\oplus(2b') \le y \to a' \oplus b' \le y/2
2a' \ge 2b' \to a' \ge b'

dp(x,y) = dp(x/2,y) + dp((x-2)/2,y/2) + dp((x-1)/2,(y-1)/2)

パターン(0,0)である時の遷移先 = dp(x/2,y/2)
パターン(1,0)である時の遷移先 = dp((x-1)/2,(y-1)/2)
パターン(1,1)である時の遷移先 = dp((x-2)/2,y/2)

youtu.be

ソースコード

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.NoSuchElementException;
import java.util.Objects;

public class Main {
	static int MOD = (int)1e9 + 7;
	long N;
	HashMap<Key,Long> map;

	private class Key implements Comparable<Key>{
		long x,y;
		public Key(long x,long y){
			this.x = x;
			this.y = y;
		}

		public int compareTo(Key p){
			if(Long.compare(this.x,p.x) == 0){
				return Long.compare(this.y, p.y);
			}
			return Long.compare(this.x, p.x);
		}

		public boolean equals(Object o){

			if(this == o){
				return true;
			}

			if(o instanceof Key){

				Key another = (Key)o;

				return this.x == another.x && this.y == another.y;
			}
			return false;
		}

		public int hashCode(){
			return Objects.hash(Long.hashCode(this.x),Long.hashCode(this.y));
		}
	}

	public long dfs(long S,long X){

		if(S == 0)return 1;

		Key key = new Key(S,X);

		if(map.containsKey(key)){
			return map.get(key);
		}

		long ret = 0;
		ret += dfs(S >> 1,X >> 1) % MOD;
		ret %= MOD;

		if(S > 1){
			ret += dfs((S - 2) >> 1,X >> 1) % MOD;
			ret %= MOD;
		}

		ret += dfs((S - 1) >> 1,(X - 1) >> 1) % MOD;
		ret %= MOD;

		map.put(key, ret);
		return ret;
	}

	public void solve() {
		N = nextLong();
		map = new HashMap<Key,Long>();
		out.println(dfs(N,N));
	}

	public static void main(String[] args) {
		out.flush();
		new Main().solve();
		out.close();
	}

	/* Input */
	private static final InputStream in = System.in;
	private static final PrintWriter out = new PrintWriter(System.out);
	private final byte[] buffer = new byte[2048];
	private int p = 0;
	private int buflen = 0;

	private boolean hasNextByte() {
		if (p < buflen)
			return true;
		p = 0;
		try {
			buflen = in.read(buffer);
		} catch (IOException e) {
			e.printStackTrace();
		}
		if (buflen <= 0)
			return false;
		return true;
	}

	public boolean hasNext() {
		while (hasNextByte() && !isPrint(buffer[p])) {
			p++;
		}
		return hasNextByte();
	}

	private boolean isPrint(int ch) {
		if (ch >= '!' && ch <= '~')
			return true;
		return false;
	}

	private int nextByte() {
		if (!hasNextByte())
			return -1;
		return buffer[p++];
	}

	public String next() {
		if (!hasNext())
			throw new NoSuchElementException();
		StringBuilder sb = new StringBuilder();
		int b = -1;
		while (isPrint((b = nextByte()))) {
			sb.appendCodePoint(b);
		}
		return sb.toString();
	}

	public int nextInt() {
		return Integer.parseInt(next());
	}

	public long nextLong() {
		return Long.parseLong(next());
	}

	public double nextDouble() {
		return Double.parseDouble(next());
	}
}