信息学竞赛题解(IO题解)算法与数据结构数据结构和算法分析

BZOJ-3531: [Sdoi2014]旅行(树链剖分+线段树

2019-02-16  本文已影响0人  AmadeusChan

题目:http://www.lydsy.com/JudgeOnline/problem.php?id=3531

好久没发过题解了。。。额。。这题树链剖分之后暴力维护10W棵线段树就可以了额。。。

代码:

#include <cstdio>

#include <algorithm>

#include <cstring>

#include <vector>

 

using namespace std ;

 

#define AddEdge( s , t ) Add( s , t ) , Add( t , s )

#define Add( s , t ) E[ s ].push_back( t )

#define travel( x ) for ( vector < int > :: iterator p = E[ x ].begin(  ) ; p != E[ x ].end(  ) ; ++ p )

 

#define L( t ) sgt[ t ].left

#define R( t ) sgt[ t ].right

#define S( t ) sgt[ t ].sum

#define M( t ) sgt[ t ].max_val

 

const int maxv = 3010000 ;

const int maxn = 100100 ;

 

struct node {

    int left , right , sum , max_val ;

} sgt[ maxv ] ;

 

int V = 0 ;

 

void INIT(  ) {

    L( 0 ) = R( 0 ) = S( 0 ) = M( 0 ) = 0 ;

}

 

void update( int t ) {

    S( t ) = S( L( t ) ) + S( R( t ) ) ;

    M( t ) = max( M( L( t ) ) , M( R( t ) ) ) ;

}

 

void Change( int l , int r , int pos , int val , int &t ) {

    if ( ! t ) t = ++ V ;

    if ( l == r ) {

        S( t ) = M( t ) = val ;

        return ;

    }

    int mid = ( l + r ) >> 1 ;

    if ( pos <= mid ) Change( l , mid , pos , val , L( t ) ) ; else

    Change( mid + 1 , r , pos , val , R( t ) ) ;

    update( t ) ;

}

 

int query_max( int l , int r , int _l , int _r , int t ) {

    if ( ! t ) return 0 ;

    if ( l == _l && r == _r ) return M( t ) ;

    int mid = ( _l + _r ) >> 1 ;

    if ( r <= mid ) return query_max( l , r , _l , mid , L( t ) ) ;

    if ( l > mid ) return query_max( l , r , mid + 1 , _r , R( t ) ) ;

    return max( query_max( l , mid , _l , mid , L( t ) ) , query_max( mid + 1 , r , mid + 1 , _r , R( t ) ) ) ;

}

 

int query_sum( int l , int r , int _l , int _r , int t ) {

    if ( ! t ) return 0 ;

    if ( l == _l && r == _r ) return S( t ) ;

    int mid = ( _l + _r ) >> 1 ;

    if ( r <= mid ) return query_sum( l , r , _l , mid , L( t ) ) ;

    if ( l > mid ) return query_sum( l , r , mid + 1 , _r , R( t ) ) ;

    return query_sum( l , mid , _l , mid , L( t ) ) + query_sum( mid + 1 , r , mid + 1 , _r , R( t ) ) ;

}

 

vector < int > E[ maxn ] ;

int arr[ maxn ] , num[ maxn ] , size[ maxn ] , child[ maxn ] , h[ maxn ] , first[ maxn ] ;

int up[ maxn ][ 21 ] , Index = 0 ;

 

int n , m , w[ maxn ] , c[ maxn ] , T[ maxn ] ;

 

void dfs0( int v , int u ) {

    size[ v ] = 1 , child[ v ] = 0 ;

    travel( v ) if ( *p != u ) {

        h[ *p ] = h[ v ] + 1 , up[ *p ][ 0 ] = v ;

        dfs0( *p , v ) ;

        size[ v ] += size[ *p ] ;

        if ( size[ child[ v ] ] < size[ *p ] ) child[ v ] = *p ;

    }

}

 

int dfs1( int v , int u , int fir ) {

    first[ v ] = fir ;

    arr[ num[ v ] = ++ Index ] = v ;

    if ( child[ v ] ) {

        dfs1( child[ v ] , v , fir ) ;

        travel( v ) if ( *p != u && *p != child[ v ] ) {

            dfs1( *p , v , *p ) ;

        }

    }

}

 

int Lca( int u , int v ) {

    if ( h[ u ] < h[ v ] ) swap( u , v ) ;

    for ( int i = 20 ; i >= 0 ; -- i ) {

        if ( h[ up[ u ][ i ] ] >= h[ v ] ) {

            u = up[ u ][ i ] ;

        }

    }

    if ( v == u ) return v ;

    for ( int i = 20 ; i >= 0 ; -- i ) {

        if ( up[ u ][ i ] != up[ v ][ i ] ) {

            u = up[ u ][ i ] , v = up[ v ][ i ] ;

        }

    }

    return up[ u ][ 0 ] ;

}

 

int Query_sum( int v , int lca , int C ) {

    int temp = 0 ;

    while ( 1 ) {

        if ( h[ lca ] < h[ first[ v ] ] ) {

            temp += query_sum( num[ first[ v ] ] , num[ v ] , 1 , n , T[ C ] ) ;

            v = up[ first[ v ] ][ 0 ] ;

        } else {

            temp += query_sum( num[ lca ] , num[ v ] , 1 , n , T[ C ] ) ;

            break ;

        }

    }

    return temp ;

}

 

int Query_max( int v , int lca , int C ) {

    int temp = 0 ;

    while ( 1 ) {

        if ( h[ lca ] < h[ first[ v ] ] ) {

            temp = max( temp , query_max( num[ first[ v ] ] , num[ v ] , 1 , n , T[ C ] ) ) ;

            v = up[ first[ v ] ][ 0 ] ;

        } else {

            temp = max( temp , query_max( num[ lca ] , num[ v ] , 1 , n , T[ C ] ) ) ;

            break ;

        }

    }

    return temp ;

}

 

int main(  ) {

    scanf( "%d%d" , &n , &m ) ;

    for ( int i = 0 ; i ++ < n ; ) scanf( "%d%d" , w + i , c + i ) ;

    for ( int i = 1 ; i < n ; ++ i ) {

        int s , t ; scanf( "%d%d" , &s , &t ) ;

        AddEdge( s , t ) ;

    }

    size[ 0 ] = 0 , h[ 1 ] = 1 ;

    memset( up , 0 , sizeof( up ) ) ;

    dfs0( 1 , 0 ) ;

    for ( int i = 0 ; i ++ < 20 ; ) {

        for ( int j = 0 ; j ++ < n ; ) {

            up[ j ][ i ] = up[ up[ j ][ i - 1 ] ][ i - 1 ] ;

        }

    }

    dfs1( 1 , 0 , 1 ) ;

    memset( T , 0 , sizeof( T ) ) ;

    for ( int i = 0 ; i ++ < n ; ) {

        Change( 1 , n , num[ i ] , w[ i ] , T[ c[ i ] ] ) ;

    }

    while ( m -- ) {

        char s[ 2 ] ; scanf( "%s" , s ) ;

        if ( s[ 0 ] == 'C' ) {

            if ( s[ 1 ] == 'C' ) {

                int x , y ; scanf( "%d%d" , &x , &y ) ;

                Change( 1 , n , num[ x ] , 0 , T[ c[ x ] ] ) ;

                Change( 1 , n , num[ x ] , w[ x ] , T[ c[ x ] = y ] ) ;

            } else {

                int x , y ; scanf( "%d%d" , &x , &y ) ;

                Change( 1 , n , num[ x ] , w[ x ] = y , T[ c[ x ] ] );

            }

        } else {

            int x , y , lca ; scanf( "%d%d" , &x , &y ) ;

            lca = Lca( x , y ) ;

            if ( s[ 1 ] == 'S' ) {

                int ans = Query_sum( x , lca , c[ y ] ) ;

                ans += Query_sum( y , lca , c[ y ] ) ;

                ans -= w[ lca ] * ( c[ lca ] == c[ y ] ) ;

                printf( "%d\n" , ans ) ;

            } else {

                int ans = Query_max( x , lca , c[ y ] ) ;

                ans = max( ans , Query_max( y , lca , c[ y ] ) ) ;

                printf( "%d\n" , ans ) ;

            }

        }

    }

    return 0 ;

}
上一篇下一篇

猜你喜欢

热点阅读