HDU-6632 (杭电多校第5场 discrete logarithm problem)

2019-08-06  本文已影响0人  叔丁基锂_

题意:已知a,b,p,求x使得a^x\equiv b (\bmod p), p<10^{18} 并且p-1=2^m3^n

题解:利用Pohlig-Hellman Algorithm来做离散对数,复杂度是\mathcal{O}\left(\sum_{i} e_{i}\left(\log n+\sqrt{p_{i}}\right)\right) ,其中\prod_i p_i^{e_i}=n,n是乘法群\mathbb{Z}_p^{*} 的阶(在这里n=p-1)

Pohlig Hellman Algorithm

——from Wikipedia

注意本算法求解y\equiv g^x (\bmod p) 要求g 是p的一个原根,所以我们需要先用该算法分别求出pa、pb使得a\equiv g^{pa}、b\equiv g^{pb} (\bmod p),把这个问题转化为g^{pa\cdot x}\equiv g^{pb} (\bmod p) , 从而 pa\cdot x\equiv pb (\bmod (p-1)),最后左右两边除以最大公因数然后用扩展gcd求逆元即可。

Wikipedia的具体过程比较抽象,建议配合这个例子 一起食用

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <unordered_map>
#include <vector>
#define int __int128
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
using namespace std;
using ll = int64_t;

// Modular exponentiation: returns x^n mod MOD using binary (square-and-multiply)
// exponentiation.
//
// The exponent is reduced modulo MOD - 1 before the loop. By Fermat's little
// theorem this is only valid when MOD is prime and x is not a multiple of MOD,
// so the two cases where the reduction would be wrong or undefined are handled
// up front:
//   * MOD == 1: every value is congruent to 0 (and `% (MOD - 1)` would divide by zero);
//   * x divisible by MOD: the result is 0 for any non-zero exponent, 1 for n == 0.
// The reduction also maps an exponent in [-(MOD - 1), 0) to its non-negative
// equivalent, which keeps the shift loop below terminating.
//
// x * x is computed before the `% MOD`, so MOD must be small enough that
// (MOD - 1)^2 fits in `int` (the file widens `int` with a macro for that reason).
int bin(int x, int n, int MOD)
{
    if (MOD == 1)
        return 0;

    x %= MOD;
    if (x == 0)
        return n == 0 ? 1 : 0;

    n = (n + MOD - 1) % (MOD - 1);

    int ret = 1;
    for (; n; n >>= 1, x = x * x % MOD)
        if (n & 1)
            ret = ret * x % MOD;
    return ret;
}
// Returns x^(p-2) mod p, delegating to bin(). By Fermat's little theorem this is
// the multiplicative inverse of x modulo p when p is prime and x is not a
// multiple of p; no check of either condition is made here.
inline int get_inv(int x, int p)
{
    const int fermat_exponent = p - 2;
    return bin(x, fermat_exponent, p);
}

// Failure marker written by go(): set to false when no digit matches at some
// level. Nothing in the visible code sets it to true or reads it (the reset in
// solve() is commented out), so callers should rely on go()'s return value.
bool flag;

// One prime-power step of Pohlig-Hellman: recovers, digit by digit in base `id`,
// the exponent e with a^e == b (mod p), reduced modulo id^up.
//
//   id   prime factor of p - 1 being processed
//   up   multiplicity of `id` in p - 1 (number of base-id digits to recover)
//   now  index of the digit recovered by this call (callers start at 0)
//   a    base of the logarithm; expected to be a primitive root modulo p
//   b    current target; already-recovered digits have been divided out
//   p    prime modulus
//
// At level `now`, b is raised to (p - 1) / id^(now + 1), which projects it into
// the subgroup of order `id`; the digit i is the one for which
// a^(i * (p - 1) / id) matches that projection. The digit's contribution
// i * id^now is then removed from b and the next level is solved recursively.
//
// Returns the sum of the contributions of digits now .. up-1 (0 when now == up).
// Returns -1 and clears `flag` if some level has no matching digit; the failure
// is propagated unchanged instead of being added into a partial sum.
int go(int id, int up, int now, int a, int b, int p)
{
    if (now == up)
        return 0;

    const int tower = (p - 1) / bin(id, now + 1, p);
    const int target = bin(b, tower, p);
    const int weight = bin(id, now, p); // id^now, the place value of this digit

    for (int i = 0; i < id; ++i)
    {
        const int c = i * (p - 1) / id;
        if (bin(a, c, p) != target)
            continue;

        const int contribution = i * weight;
        // Multiply b by a^(-contribution); the exponent is taken modulo p - 1.
        const int nextb = (b * bin(a, (p - 1 - contribution) % (p - 1), p)) % p;
        const int rest = go(id, up, now + 1, a, nextb, p);
        if (rest < 0)
            return -1;
        return contribution + rest;
    }

    flag = false;
    return -1;
}

// Extended Euclidean algorithm.
// Returns g = gcd(a, b) and stores in x, y a pair of Bezout coefficients with
// a * x + b * y == g.
int ex_gcd(int a, int b, int &x, int &y)
{
    if (b != 0)
    {
        // Solve for (b, a % b) with the output slots swapped, then correct y
        // for the quotient that was removed from a.
        const int g = ex_gcd(b, a % b, y, x);
        const int quotient = a / b;
        y -= quotient * x;
        return g;
    }

    // Base case: gcd(a, 0) == a == a * 1 + 0 * 0.
    x = 1;
    y = 0;
    return a;
}

int CRT(int *m, int *r, int n)
{
    if (!n)
        return 0;
    int M = m[0], R = r[0], x, y, d;
    FOR(i, 1, n)
    {
        d = ex_gcd(M, m[i], x, y);
        if ((r[i] - R) % d)
            return -1;
        x = (r[i] - R) / d * x % (m[i] / d);
        R += x * M;
        M = M / d * m[I];
        R %= M;
    }
    return R >= 0 ? R : R + M;
}

int find_smallest_primitive_root(int p, const map<int, int> &ma)
{
    static unordered_map<int, int> ans;
    if (ans[p])
        return ans[p];
    for (int i = 2; i < p; I++)
    {
        bool flag = true;
        for (auto pa : ma)
        {
            if (bin(i, (p - 1) / pa.first, p) == 1)
            {
                flag = false;
                break;
            }
        }
        if (flag)
            return ans[p] = I;
    }
    return -1;
}

int solve(int a, int b, int p, const map<int, int> &ma)
{
    // flag = true;
    int m[ma.size()], r[ma.size()];
    int cnt = 0;
    for (auto pa : ma)
    {
        r[cnt] = go(pa.first, pa.second, 0ll, a, b, p);
        m[cnt] = bin(pa.first, pa.second, p);
        cnt++;
    }
    ll ans = CRT(m, r, ma.size());
    return ans;
}

// Entry point. Reads the number of test cases, then for each case reads
// p, a, b (in that order) and prints one x with a^x == b (mod p), or -1 if
// there is none.
//
// Per case: p - 1 is factorised by trial division with 2..5, a primitive root g
// of p is found, pa = log_g(a) and pb = log_g(b) are computed, and the linear
// congruence pa * x == pb (mod p - 1) is solved.
int32_t main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int32_t t = 0;
    cin >> t;
    while (t--)
    {
        int64_t a, b, p;
        if (!(cin >> p >> a >> b))
            break; // truncated input: stop instead of working on garbage

        // Factorise p - 1 over the small candidates 2..5.
        int tmp = p - 1;
        map<int, int> ma;
        for (int i = 2; i <= 5; i++)
        {
            while (tmp % i == 0)
            {
                ma[i]++;
                tmp /= i;
            }
        }

        const int rt = find_smallest_primitive_root(p, ma);
        const int pa = solve(rt, a, p, ma);
        const int pb = solve(rt, b, p, ma);

        int answer;
        if (pa == 0)
        {
            // a == 1: only b == 1 is reachable, with x = 0.
            answer = (pb == 0) ? 0 : -1;
        }
        else if (pb % pa == 0)
        {
            // Exact integer quotient is already a solution.
            answer = pb / pa;
        }
        else
        {
            int x, y;
            const int g = ex_gcd(pa, p - 1, x, y);
            if (pb % g)
            {
                answer = -1;
            }
            else
            {
                // Divide the congruence by g, then multiply by the inverse of
                // pa / g modulo (p - 1) / g obtained from the extended gcd.
                const int reduced_pa = pa / g;
                const int reduced_pb = pb / g;
                const int mmod = (p - 1) / g;
                ex_gcd(reduced_pa, mmod, x, y);
                answer = reduced_pb * (x + mmod) % mmod;
            }
        }
        // '\n' rather than endl: no flush per test case.
        cout << ll(answer) << '\n';
    }
}

其它的参考:

上一篇下一篇

猜你喜欢

热点阅读