Android GradleAndroid ASM

Gradle插件<第三篇>:字节码插桩技术-替换类

2021-08-23  本文已影响0人  NoBugException

字节码插桩技术有:Javassist、BCEL、ASM,它们的执行速度对比图如下:

图片.png

ASM的执行速率要比Javassist和BCEL快得多,所以本文主要使用ASM实现字节码插桩。

假设项目中自定义了一个叫MyApplication的Application,代码如下:

package com.nobug.classtest;

import android.app.Application;
import android.util.Log;

public class MyApplication extends Application { // sample Application subclass; the class the article renames via bytecode instrumentation

    private String TAG = MyApplication.class.getSimpleName(); // instance (non-static) field holding the simple class name, used as the log tag

    @Override
    public void onCreate() { // delegates to Application.onCreate(), then writes one debug log line
        super.onCreate();
        Log.d(TAG, "my application create");
    }
}

通过字节码插桩的方式将com.nobug.classtest.MyApplication替换成com.nobug.classtest.MyCustomApplition。

将上面的Java代码转换成字节码之后:

// class version 51.0 (51)
// access flags 0x21
public class com/nobug/classtest/MyApplication extends android/app/Application  {

// compiled from: MyApplication.java

// access flags 0x2
private Ljava/lang/String; TAG

// access flags 0x1
public <init>()V
        L0
        LINENUMBER 6 L0
        ALOAD 0
        INVOKESPECIAL android/app/Application.<init> ()V
        L1
        LINENUMBER 8 L1
        ALOAD 0
        LDC Lcom/nobug/classtest/MyApplication;.class
        INVOKEVIRTUAL java/lang/Class.getSimpleName ()Ljava/lang/String;
        PUTFIELD com/nobug/classtest/MyApplication.TAG : Ljava/lang/String;
        RETURN
        L2
        LOCALVARIABLE this Lcom/nobug/classtest/MyApplication; L0 L2 0
        MAXSTACK = 2
        MAXLOCALS = 1

// access flags 0x1
public onCreate()V
        L0
        LINENUMBER 12 L0
        ALOAD 0
        INVOKESPECIAL android/app/Application.onCreate ()V
        L1
        LINENUMBER 13 L1
        ALOAD 0
        GETFIELD com/nobug/classtest/MyApplication.TAG : Ljava/lang/String;
        LDC "my application create"
        INVOKESTATIC android/util/Log.d (Ljava/lang/String;Ljava/lang/String;)I
        POP
        L2
        LINENUMBER 14 L2
        RETURN
        L3
        LOCALVARIABLE this Lcom/nobug/classtest/MyApplication; L0 L3 0
        MAXSTACK = 2
        MAXLOCALS = 1
        }

下面需要做的是:将字节码中所有的com/nobug/classtest/MyApplication替换成com/nobug/classtest/MyCustomApplition,同时将MyApplication换成MyCustomApplition。

插件:

package com.nobug.plugintest

import com.android.build.api.transform.DirectoryInput
import com.android.build.api.transform.Format
import com.android.build.api.transform.JarInput
import com.android.build.api.transform.QualifiedContent
import com.android.build.api.transform.Transform
import com.android.build.api.transform.TransformException
import com.android.build.api.transform.TransformInput
import com.android.build.api.transform.TransformInvocation
import com.android.build.api.transform.TransformOutputProvider
import com.android.build.gradle.AppExtension
import com.android.build.gradle.internal.pipeline.TransformManager
import com.android.utils.FileUtils
import com.nobug.plugindemo.MyCustomClassVisitor
import org.apache.commons.codec.digest.DigestUtils
import org.apache.commons.io.IOUtils
import org.gradle.api.Plugin;
import org.gradle.api.Project
import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassVisitor
import org.objectweb.asm.ClassWriter
import java.util.jar.JarEntry
import java.util.jar.JarFile
import java.util.jar.JarOutputStream
import java.util.zip.ZipEntry;

/**
 * Gradle plugin that registers itself as an Android build {@link Transform} and pipes every
 * eligible class file, from both directory and jar inputs, through {@link MyCustomClassVisitor}.
 *
 * The transform is non-incremental: each run wipes the previous output and reprocesses all inputs.
 */
class PluginTest extends Transform implements Plugin<Project> {

    Project project;

    /**
     * Registers this transform on the project's Android application extension.
     *
     * @param project the project the plugin is applied to
     */
    @Override
    void apply(Project project) {
        this.project = project;
        def android = project.extensions.getByType(AppExtension)
        android.registerTransform(this)
    }


    /** @return the unique name of this transform */
    @Override
    String getName() {
        return "ASMPlugin"
    }

    /** @return the content types consumed by this transform (compiled classes) */
    @Override
    Set<QualifiedContent.ContentType> getInputTypes() {
        return TransformManager.CONTENT_CLASS
    }

    /** @return the scopes consumed by this transform (the full project) */
    @Override
    Set<? super QualifiedContent.Scope> getScopes() {
        return TransformManager.SCOPE_FULL_PROJECT
    }

    /** @return {@code false}; every invocation reprocesses all inputs */
    @Override
    boolean isIncremental() {
        return false
    }

    /**
     * Clears the previous output and processes every directory input and jar input.
     *
     * @param transformInvocation supplies the inputs and the output provider
     */
    @Override
    void transform(TransformInvocation transformInvocation) throws TransformException, InterruptedException, IOException {
        Collection<TransformInput> inputs = transformInvocation.inputs
        TransformOutputProvider outputProvider = transformInvocation.outputProvider
        // Non-incremental transform: drop whatever a previous run left behind.
        if (outputProvider != null) {
            outputProvider.deleteAll()
        }
        inputs.each { input ->
            input.directoryInputs.each {
                directoryInput -> handleDirectoryInput(directoryInput, outputProvider)
            }
            input.jarInputs.each {
                jarInput -> handleJarInput(jarInput, outputProvider)
            }
        }
    }

    /**
     * Copies a directory input to its output location and instruments the class files of the copy.
     *
     * The copy is made first and only the copied files are rewritten, so the files produced by
     * the upstream task are left untouched.
     *
     * @param directoryInput the directory of compiled classes to process
     * @param outputProvider resolves where the processed directory must be written
     */
    void handleDirectoryInput(DirectoryInput directoryInput, TransformOutputProvider outputProvider) {
        def dest = outputProvider.getContentLocation(directoryInput.name, directoryInput.contentTypes, directoryInput.scopes, Format.DIRECTORY)
        FileUtils.copyDirectory(directoryInput.file, dest)
        if (dest.isDirectory()) {
            // Walks the copied tree recursively (sub-directories included).
            dest.eachFileRecurse { file ->
                if (file.isFile() && isClassFile(file.name)) {
                    // File.setBytes opens, writes and closes the stream even when writing fails.
                    file.bytes = instrument(file.bytes)
                }
            }
        }
    }

    /**
     * Rewrites a jar input entry by entry, instrumenting eligible class entries, and stores the
     * result at the jar's output location.
     *
     * @param jarInput the jar to process
     * @param outputProvider resolves where the processed jar must be written
     */
    void handleJarInput(JarInput jarInput, TransformOutputProvider outputProvider) {
        // Output names are made unique with a hash of the input path, because different
        // inputs may share the same name and would otherwise overwrite each other.
        def jarName = jarInput.name
        if (jarName.endsWith(".jar")) {
            jarName = jarName.substring(0, jarName.length() - 4)
        }
        def md5Name = DigestUtils.md5Hex(jarInput.file.absolutePath)
        def dest = outputProvider.getContentLocation(jarName + "_" + md5Name, jarInput.contentTypes, jarInput.scopes, Format.JAR)

        if (!jarInput.file.absolutePath.endsWith(".jar")) {
            // Not something we know how to rewrite: pass it through unchanged instead of
            // silently dropping it from the build output.
            if (jarInput.file.isFile()) {
                FileUtils.copyFile(jarInput.file, dest)
            }
            return
        }

        // A uniquely named temp file avoids collisions between jars that live in the same
        // directory and keeps the input directory clean. The prefix is plain hex, so it is
        // always a valid file-name prefix.
        File tempFile = File.createTempFile(md5Name, ".jar")
        try {
            JarFile jarFile = new JarFile(jarInput.file)
            try {
                new JarOutputStream(new FileOutputStream(tempFile)).withStream { jarOutputStream ->
                    Enumeration<JarEntry> entries = jarFile.entries()
                    while (entries.hasMoreElements()) {
                        JarEntry jarEntry = entries.nextElement()
                        String entryName = jarEntry.name
                        // withStream closes the entry stream whether or not reading succeeds.
                        byte[] content = jarFile.getInputStream(jarEntry).withStream { IOUtils.toByteArray(it) }
                        // A fresh ZipEntry carries only the name, so sizes and CRC are
                        // recomputed for the (possibly rewritten) content.
                        jarOutputStream.putNextEntry(new ZipEntry(entryName))
                        jarOutputStream.write(isClassFile(entryName) ? instrument(content) : content)
                        jarOutputStream.closeEntry()
                    }
                }
            } finally {
                jarFile.close()
            }
            FileUtils.copyFile(tempFile, dest)
        } finally {
            tempFile.delete()
        }
    }

    /**
     * Runs one class file through {@link MyCustomClassVisitor}.
     *
     * @param classBytes the original class file content
     * @return the class file content produced by the visitor chain
     */
    private static byte[] instrument(byte[] classBytes) {
        ClassReader classReader = new ClassReader(classBytes)
        ClassWriter classWriter = new ClassWriter(classReader, ClassWriter.COMPUTE_MAXS)
        ClassVisitor classVisitor = new MyCustomClassVisitor(classWriter)
        classReader.accept(classVisitor, ClassReader.EXPAND_FRAMES)
        return classWriter.toByteArray()
    }

    /**
     * Tells whether a file or jar entry is a class file that should be instrumented.
     *
     * Generated {@code R}, {@code R$...} and {@code BuildConfig} classes are excluded.
     *
     * @param name a plain file name, or a jar entry name that may include a package path
     * @return {@code true} when the class should be processed
     */
    boolean isClassFile(String name) {
        // Jar entry names include the package path (e.g. "com/foo/R$id.class"), so the
        // exclusions have to be checked against the last path segment only.
        String simpleName = name.substring(name.lastIndexOf('/') + 1)
        return (simpleName.endsWith(".class") && !simpleName.startsWith("R\$")
                && "R.class" != simpleName && "BuildConfig.class" != simpleName)
    }
}

MyCustomClassVisitor.groovy

package com.nobug.plugindemo;

import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

/**
 * Class-level visitor that renames {@code com/nobug/classtest/MyApplication} to
 * {@code com/nobug/classtest/MyCustomApplition}: the class name, the super-class name, the
 * source-file attribute, and field/method descriptors are rewritten. Method bodies are handed
 * to {@link MyCustomMethodVisitor}.
 *
 * NOTE(review): generic signatures and interface lists are forwarded unchanged.
 */
class MyCustomClassVisitor extends ClassVisitor implements Opcodes {

    private static final String MY_APPLITION_NAME = "com/nobug/classtest/MyApplication";
    private static final String REPLACE_APPLITION_NAME = "com/nobug/classtest/MyCustomApplition";
    private static final String JAVA_APPLITION_NAME = "MyApplication.java";
    private static final String JAVA_REPLACE_APPLITION_NAME = "MyCustomApplition.java";

    // Descriptor forms ("Lname;"). Matching the full token, including the trailing ';',
    // keeps classes that merely share the prefix (MyApplicationHelper, MyApplication$Inner)
    // from being rewritten by accident.
    private static final String MY_APPLITION_DESC = "L" + MY_APPLITION_NAME + ";";
    private static final String REPLACE_APPLITION_DESC = "L" + REPLACE_APPLITION_NAME + ";";

    /**
     * @param cv the visitor that receives the rewritten class
     */
    public MyCustomClassVisitor(ClassVisitor cv) {
        super(Opcodes.ASM5, cv);
    }

    /**
     * Renames the class itself and/or its direct super class when either is the target class.
     */
    @Override
    public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
        if (MY_APPLITION_NAME.equals(name)) {
            name = REPLACE_APPLITION_NAME;
        }
        if (MY_APPLITION_NAME.equals(superName)) {
            superName = REPLACE_APPLITION_NAME;
        }
        super.visit(version, access, name, signature, superName, interfaces);
    }

    /**
     * Rewrites the source-file attribute when it names the target's source file.
     */
    @Override
    public void visitSource(String source, String debug) {
        if (JAVA_APPLITION_NAME.equals(source)) {
            source = JAVA_REPLACE_APPLITION_NAME;
        }
        super.visitSource(source, debug);
    }

    /**
     * Rewrites references to the target class inside the field descriptor.
     */
    @Override
    public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
        // String.replace returns the input unchanged when there is nothing to replace.
        String replaceDesc = descriptor.replace(MY_APPLITION_DESC, REPLACE_APPLITION_DESC);
        return super.visitField(access, name, replaceDesc, signature, value);
    }

    /**
     * Rewrites references to the target class inside the method descriptor and wraps the
     * downstream method visitor so that the method body is rewritten as well.
     */
    @Override
    public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) {
        String replaceDesc = desc.replace(MY_APPLITION_DESC, REPLACE_APPLITION_DESC);
        // super.visitMethod tolerates a null delegate, unlike dereferencing cv directly.
        MethodVisitor methodVisitor = super.visitMethod(access, name, replaceDesc, signature, exceptions);
        return new MyCustomMethodVisitor(Opcodes.ASM5, methodVisitor);
    }
}

MyCustomMethodVisitor.groovy

package com.nobug.plugindemo;

import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Type;

/**
 * Method-level visitor that rewrites references to {@code com/nobug/classtest/MyApplication}
 * so they point at {@code com/nobug/classtest/MyCustomApplition}. It covers class constants,
 * field and method instructions, type instructions, and local-variable debug entries.
 *
 * NOTE(review): stack-map frames, invokedynamic, try/catch types, multi-dimensional array
 * creation and generic signatures are forwarded unchanged; asm-commons' ClassRemapper is
 * worth considering if complete coverage is needed.
 */
public class MyCustomMethodVisitor extends MethodVisitor {

    private static final String MY_APPLITION_NAME = "com/nobug/classtest/MyApplication";
    private static final String REPLACE_APPLITION_NAME = "com/nobug/classtest/MyCustomApplition";

    // Descriptor forms ("Lname;"). Matching the full token, including the trailing ';',
    // keeps classes that merely share the prefix (MyApplicationHelper, MyApplication$Inner)
    // from being rewritten by accident.
    private static final String MY_APPLITION_DESC = "L" + MY_APPLITION_NAME + ";";
    private static final String REPLACE_APPLITION_DESC = "L" + REPLACE_APPLITION_NAME + ";";

    /**
     * @param api the ASM API level
     * @param mv  the visitor that receives the rewritten instructions
     */
    public MyCustomMethodVisitor(int api, MethodVisitor mv) {
        super(api, mv);
    }

    /**
     * Rewrites {@code Type} constants (e.g. {@code MyApplication.class}) that mention the target class.
     */
    @Override
    public void visitLdcInsn(Object value) {
        if (value instanceof Type) {
            String typeDescriptor = ((Type) value).getDescriptor();
            String remapped = remapDescriptor(typeDescriptor);
            if (!remapped.equals(typeDescriptor)) {
                value = Type.getType(remapped);
            }
        }
        super.visitLdcInsn(value);
    }

    /**
     * Rewrites the owner and the field type of field accesses.
     */
    @Override
    public void visitFieldInsn(int opcode, String owner, String name, String descriptor) {
        super.visitFieldInsn(opcode, remapInternalName(owner), name, remapDescriptor(descriptor));
    }

    /**
     * Rewrites the owner and the signature of method calls, so that callers of the renamed
     * class keep resolving.
     */
    @Override
    public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) {
        super.visitMethodInsn(opcode, remapInternalName(owner), name, remapDescriptor(descriptor), isInterface);
    }

    /**
     * Rewrites the operand of NEW / ANEWARRAY / CHECKCAST / INSTANCEOF.
     */
    @Override
    public void visitTypeInsn(int opcode, String type) {
        super.visitTypeInsn(opcode, remapInternalName(type));
    }

    /**
     * Rewrites the declared type of local-variable debug entries.
     */
    @Override
    public void visitLocalVariable(String name, String descriptor, String signature, Label start, Label end, int index) {
        super.visitLocalVariable(name, remapDescriptor(descriptor), signature, start, end, index);
    }

    /** Replaces every {@code Lold;} token in a field or method descriptor with {@code Lnew;}. */
    private static String remapDescriptor(String descriptor) {
        // String.replace returns the input unchanged when there is nothing to replace.
        return descriptor.replace(MY_APPLITION_DESC, REPLACE_APPLITION_DESC);
    }

    /**
     * Rewrites an instruction operand that is either a plain internal name or, for array
     * types, a descriptor such as {@code [Lcom/foo/Bar;}.
     */
    private static String remapInternalName(String internalName) {
        if (MY_APPLITION_NAME.equals(internalName)) {
            return REPLACE_APPLITION_NAME;
        }
        // A plain internal name never contains ';', so this only affects array descriptors.
        return remapDescriptor(internalName);
    }
}

最后,别忘了在AndroidManifest.xml的<application>标签中通过android:name声明Application的名称,如图:

图片.png

[本章完...]

上一篇 下一篇

猜你喜欢

热点阅读