41 - ASM之优化、删除等复杂的变换

2022-02-21  本文已影响0人  舍是境界

复杂的变换

stateless transformations

The stateless transformation does not depend on the instructions that have been visited before the current one.

举几个关于stateless transformation的例子:

这种stateless transformation实现起来比较容易,所以也被称为simple transformations。

stateful transformations

A stateful transformation requires memorizing some state about the instructions that have been visited before the current one.
This requires storing state inside the method adapter.

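举几个关于stateful transformation的例子(也就是本文后面的三个示例):

- 删除ICONST_0 IADD指令组合:例如,将int d = c + 0;转换成int d = c;。
- 删除ALOAD_0 ALOAD_0 GETFIELD PUTFIELD指令组合:例如,删除this.val = this.val;这样的自我赋值语句。
- 删除GETSTATIC LDC INVOKEVIRTUAL指令组合:例如,删除System.out.println("...");这样打印字符串的语句。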

这种stateful transformation实现起来比较困难,所以也被称为complex transformations。

那么,为什么stateless transformation实现起来比较容易,而stateful transformation会实现起来比较困难呢?做个类比,stateless transformation就类似于“一人吃饱,全家不饿”,不用考虑太多,所以实现起来就比较简单;而stateful transformation类似于“成家之后,要考虑一家人的生活状态”,考虑的事情就多一点,所以实现起来就比较困难。难归难,但是我们还是应该想办法进行实现。

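那么,stateful transformation到底该如何开始着手实现呢?在stateful transformation过程中,一般都是涉及到对多个指令(Instruction)同时判断,这多个指令是一个“组合”,不能轻易拆散。我们通过三个步骤来进行实现:

- 第一步,将问题转换成具体的指令(Instruction)序列,总结出这个指令“组合”所遵循的模式(pattern)。
- 第二步,使用总结出来的模式对指令进行识别;在识别过程中,每遇到一条指令,都可能引起状态(state)的变化,这对应着stateful的部分。
- 第三步,识别成功之后,对指令进行转换(例如删除或替换),这对应着transformation的部分。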

到这里,就有一个新的问题产生:如何去记录第二步当中的状态(state)变化呢?我们的回答就是,借助于state machine。

state machine

什么是state machine?

A state machine is a behavior model. It consists of a finite number of states and is therefore also called finite-state machine (FSM). Based on the current state and a given input the machine performs state transitions and produces outputs.

state machine的聪明之处,就是将“无限”的操作步骤给限定在“有限”的状态里来思考。

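接下来,就是给出一个具体的state machine。也就是说,下面的MethodPatternAdapter类,就是一个原始的state machine,我们从三个层面来把握它:

- 第一个层面,类:MethodPatternAdapter是一个抽象类,继承自MethodVisitor类。
- 第二个层面,字段:SEEN_NOTHING常量表示初始状态,state字段记录当前状态。
- 第三个层面,方法:各个visitXxxInsn()方法(以及visitLabel、visitFrame、visitMaxs等方法)都会先调用抽象的visitInsn()方法,再调用父类的同名方法;visitInsn()方法由子类来实现。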

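那么,应该怎么使用MethodPatternAdapter类呢?我们就是写一个MethodPatternAdapter类的子类,这个子类就是一个更“先进”的state machine,它做以下三件事情:

- 第一件事情,定义新的状态常量(例如SEEN_ICONST_0),用来表示已经识别出的部分指令。
- 第二件事情,覆写感兴趣的visitXxxInsn()方法,根据当前状态和遇到的指令进行状态迁移;对于不感兴趣的情况,交给父类处理。
- 第三件事情,实现visitInsn()方法:根据state的值,把已经“吞掉”但没有匹配成功的指令重新输出,并将state恢复为SEEN_NOTHING。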

import org.objectweb.asm.Handle;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;

/**
 * Base class for stateful method transformations that recognise a short
 * sequence of instructions (a "pattern") and rewrite it.
 *
 * <p>The class is a minimal state machine: {@link #state} holds the current
 * state and {@link #SEEN_NOTHING} is the initial one. Every callback overridden
 * here first invokes the {@link #visitInsn()} hook and then forwards the event
 * unchanged to the next visitor. A subclass overrides only the callbacks that
 * take part in its pattern; for everything else the hook gives it the chance to
 * emit whatever it has buffered and to fall back to {@link #SEEN_NOTHING}.
 *
 * <p>The deprecated four-argument {@code visitMethodInsn} is intentionally NOT
 * overridden. {@link MethodVisitor} forwards that legacy entry point to the
 * five-argument overload; overriding it here would run the hook (and therefore
 * reset the state) before a subclass's five-argument override had a chance to
 * inspect the pending pattern, and would run the hook a second time on the way
 * through the five-argument overload.
 */
public abstract class MethodPatternAdapter extends MethodVisitor {
    /** Initial state: no part of a pattern has been recognised yet. */
    protected static final int SEEN_NOTHING = 0;

    /** Current state of the matcher; starts out as {@link #SEEN_NOTHING}. */
    protected int state = SEEN_NOTHING;

    /**
     * @param api           the ASM API version, passed straight to {@link MethodVisitor}
     * @param methodVisitor the visitor that receives the (possibly rewritten) events
     */
    public MethodPatternAdapter(int api, MethodVisitor methodVisitor) {
        super(api, methodVisitor);
    }

    @Override
    public void visitInsn(int opcode) {
        visitInsn();
        super.visitInsn(opcode);
    }

    @Override
    public void visitIntInsn(int opcode, int operand) {
        visitInsn();
        super.visitIntInsn(opcode, operand);
    }

    @Override
    public void visitVarInsn(int opcode, int var) {
        visitInsn();
        super.visitVarInsn(opcode, var);
    }

    @Override
    public void visitTypeInsn(int opcode, String type) {
        visitInsn();
        super.visitTypeInsn(opcode, type);
    }

    @Override
    public void visitFieldInsn(int opcode, String owner, String name, String descriptor) {
        visitInsn();
        super.visitFieldInsn(opcode, owner, name, descriptor);
    }

    @Override
    public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) {
        visitInsn();
        super.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
    }

    @Override
    public void visitInvokeDynamicInsn(String name, String descriptor, Handle bootstrapMethodHandle, Object... bootstrapMethodArguments) {
        visitInsn();
        super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);
    }

    @Override
    public void visitJumpInsn(int opcode, Label label) {
        visitInsn();
        super.visitJumpInsn(opcode, label);
    }

    @Override
    public void visitLdcInsn(Object value) {
        visitInsn();
        super.visitLdcInsn(value);
    }

    @Override
    public void visitIincInsn(int var, int increment) {
        visitInsn();
        super.visitIincInsn(var, increment);
    }

    @Override
    public void visitTableSwitchInsn(int min, int max, Label dflt, Label... labels) {
        visitInsn();
        super.visitTableSwitchInsn(min, max, dflt, labels);
    }

    @Override
    public void visitLookupSwitchInsn(Label dflt, int[] keys, Label[] labels) {
        visitInsn();
        super.visitLookupSwitchInsn(dflt, keys, labels);
    }

    @Override
    public void visitMultiANewArrayInsn(String descriptor, int numDimensions) {
        visitInsn();
        super.visitMultiANewArrayInsn(descriptor, numDimensions);
    }

    // The callbacks below are not instructions, but they also run the hook so
    // that a partially recognised pattern is never carried across them.

    @Override
    public void visitTryCatchBlock(Label start, Label end, Label handler, String type) {
        visitInsn();
        super.visitTryCatchBlock(start, end, handler, type);
    }

    @Override
    public void visitLabel(Label label) {
        visitInsn();
        super.visitLabel(label);
    }

    @Override
    public void visitFrame(int type, int numLocal, Object[] local, int numStack, Object[] stack) {
        visitInsn();
        super.visitFrame(type, numLocal, local, numStack, stack);
    }

    @Override
    public void visitMaxs(int maxStack, int maxLocals) {
        visitInsn();
        super.visitMaxs(maxStack, maxLocals);
    }

    /**
     * Hook invoked before every event that this class forwards.
     *
     * <p>The subclasses in this article implement it by emitting the
     * instructions they swallowed while a pattern was only partially matched
     * and then setting {@link #state} back to {@link #SEEN_NOTHING}.
     */
    protected abstract void visitInsn();
}

示例一:加零

预期目标

假如有一个HelloWorld类,代码如下:

// Sample input class for the "add zero" example.
public class HelloWorld {
    public void test(int a, int b) {
        int c = a + b;
        // Redundant "+ 0": the goal of this example is to turn this statement into "int d = c;".
        int d = c + 0;
        System.out.println(d);
    }
}

我们想要实现的预期目标:将int d = c + 0;转换成int d = c;。

$ javap -c sample.HelloWorld
Compiled from "HelloWorld.java"
public class sample.HelloWorld {
...
  public void test(int, int);
    Code:
       0: iload_1
       1: iload_2
       2: iadd
       3: istore_3
       4: iload_3
       5: iconst_0
       6: iadd
       7: istore        4
       9: getstatic     #2                  // Field java/lang/System.out:Ljava/io/PrintStream;
      12: iload         4
      14: invokevirtual #3                  // Method java/io/PrintStream.println:(I)V
      17: return
}

编码实现

import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;

import static org.objectweb.asm.Opcodes.*;

/**
 * Class visitor that wraps every concrete, non-constructor method in an adapter
 * which drops the instruction pair {@code ICONST_0 IADD} (adding zero to an int).
 */
public class MethodRemoveAddZeroVisitor extends ClassVisitor {
    public MethodRemoveAddZeroVisitor(int api, ClassVisitor classVisitor) {
        super(api, classVisitor);
    }

    @Override
    public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
        MethodVisitor delegate = cv.visitMethod(access, name, descriptor, signature, exceptions);
        if (delegate == null) {
            return null;
        }
        // Constructors and static initialisers are passed through untouched.
        if ("<init>".equals(name) || "<clinit>".equals(name)) {
            return delegate;
        }
        // Abstract and native methods are passed through untouched as well.
        if ((access & ACC_ABSTRACT) != 0 || (access & ACC_NATIVE) != 0) {
            return delegate;
        }
        return new MethodRemoveAddZeroAdapter(api, delegate);
    }

    /**
     * State machine with two states: {@code SEEN_NOTHING} and
     * {@code SEEN_ICONST_0} (an {@code ICONST_0} has been swallowed and is
     * waiting to see whether an {@code IADD} follows).
     */
    private class MethodRemoveAddZeroAdapter extends MethodPatternAdapter {
        private static final int SEEN_ICONST_0 = 1;

        public MethodRemoveAddZeroAdapter(int api, MethodVisitor methodVisitor) {
            super(api, methodVisitor);
        }

        @Override
        public void visitInsn(int opcode) {
            // SEEN_NOTHING + ICONST_0: swallow the constant and remember it.
            if (state == SEEN_NOTHING && opcode == ICONST_0) {
                state = SEEN_ICONST_0;
                return;
            }

            if (state == SEEN_ICONST_0) {
                if (opcode == IADD) {
                    // Pattern complete: neither ICONST_0 nor IADD is emitted.
                    state = SEEN_NOTHING;
                    return;
                }
                if (opcode == ICONST_0) {
                    // Emit the older ICONST_0; the new one becomes the pending one.
                    mv.visitInsn(ICONST_0);
                    return;
                }
            }

            // Anything else goes through the parent, which runs the flush hook first.
            super.visitInsn(opcode);
        }

        /** Emits a pending {@code ICONST_0}, if any, and returns to the initial state. */
        @Override
        protected void visitInsn() {
            boolean zeroPending = (state == SEEN_ICONST_0);
            if (zeroPending) {
                mv.visitInsn(ICONST_0);
            }
            state = SEEN_NOTHING;
        }
    }
}

进行转换

import lsieun.utils.FileUtils;
import org.objectweb.asm.*;

/**
 * Driver for the "add zero" example: reads sample/HelloWorld.class, runs it
 * through {@code MethodRemoveAddZeroVisitor} and writes the result back to the
 * same file.
 */
public class HelloWorldTransformCore {
    public static void main(String[] args) {
        // Locate and load the class file that is rewritten in place.
        final String classFile = FileUtils.getFilePath("sample/HelloWorld.class");
        final byte[] originalBytes = FileUtils.readBytes(classFile);

        // Reader -> transforming visitor -> writer.
        ClassReader reader = new ClassReader(originalBytes);
        ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
        ClassVisitor chain = new MethodRemoveAddZeroVisitor(Opcodes.ASM9, writer);

        // Debug info and the original stack map frames are skipped while reading.
        reader.accept(chain, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);

        // Serialise the transformed class and overwrite the input file.
        FileUtils.writeBytes(classFile, writer.toByteArray());
    }
}

验证结果

$ javap -c sample.HelloWorld
public class sample.HelloWorld {
...
  public void test(int, int);
    Code:
       0: iload_1
       1: iload_2
       2: iadd
       3: istore_3
       4: iload_3
       5: istore        4
       7: getstatic     #16                 // Field java/lang/System.out:Ljava/io/PrintStream;
      10: iload         4
      12: invokevirtual #22                 // Method java/io/PrintStream.println:(I)V
      15: return
}

示例二:字段赋值

预期目标

假如有一个HelloWorld类,代码如下:

// Sample input class for the "field self-assignment" example.
public class HelloWorld {
    public int val;

    public void test(int a, int b) {
        int c = a + b;
        // Self-assignment: the goal of this example is to delete this statement.
        this.val = this.val;
        System.out.println(c);
    }
}

我们想要实现的预期目标:删除掉this.val = this.val;语句。

$ javap -c sample.HelloWorld
Compiled from "HelloWorld.java"
public class sample.HelloWorld {
  public int val;

...

  public void test(int, int);
    Code:
       0: iload_1
       1: iload_2
       2: iadd
       3: istore_3
       4: aload_0
       5: aload_0
       6: getfield      #2                  // Field val:I
       9: putfield      #2                  // Field val:I
      12: getstatic     #3                  // Field java/lang/System.out:Ljava/io/PrintStream;
      15: iload_3
      16: invokevirtual #4                  // Method java/io/PrintStream.println:(I)V
      19: return
}

编码实现

状态示意图
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;

import static org.objectweb.asm.Opcodes.*;

/**
 * Class visitor that removes field self-assignments of the form
 * {@code this.f = this.f}, i.e. the instruction sequence
 * {@code ALOAD 0, ALOAD 0, GETFIELD f, PUTFIELD f}.
 *
 * <p>Only concrete instance methods other than constructors are transformed.
 * Static methods are skipped because local variable 0 is not {@code this}
 * there: the same instruction sequence would mean {@code p.f = p.f} on an
 * arbitrary reference, and deleting it would also delete a possible
 * {@code NullPointerException}.
 *
 * <p>NOTE(review): a method visitor cannot see field modifiers, so a
 * self-assignment of a {@code volatile} field is removed as well — confirm that
 * this is acceptable for the classes being processed.
 */
public class MethodRemoveGetFieldPutFieldVisitor extends ClassVisitor {
    public MethodRemoveGetFieldPutFieldVisitor(int api, ClassVisitor classVisitor) {
        super(api, classVisitor);
    }

    @Override
    public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
        // super.visitMethod returns null instead of failing when no delegate is set.
        MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
        if (mv != null && !"<init>".equals(name) && !"<clinit>".equals(name)) {
            boolean isAbstractMethod = (access & ACC_ABSTRACT) != 0;
            boolean isNativeMethod = (access & ACC_NATIVE) != 0;
            boolean isStaticMethod = (access & ACC_STATIC) != 0;
            if (!isAbstractMethod && !isNativeMethod && !isStaticMethod) {
                mv = new MethodRemoveGetFieldPutFieldAdapter(api, mv);
            }
        }
        return mv;
    }

    /**
     * State machine recognising {@code ALOAD 0, ALOAD 0, GETFIELD, PUTFIELD}
     * where both field instructions name exactly the same field (owner, name
     * and descriptor).
     */
    private static final class MethodRemoveGetFieldPutFieldAdapter extends MethodPatternAdapter {
        private static final int SEEN_ALOAD_0 = 1;
        private static final int SEEN_ALOAD_0_ALOAD_0 = 2;
        private static final int SEEN_ALOAD_0_ALOAD_0_GETFIELD = 3;

        // Field referenced by the swallowed GETFIELD; needed to compare with the
        // PUTFIELD and to re-emit the GETFIELD when the pattern does not complete.
        private String fieldOwner;
        private String fieldName;
        private String fieldDesc;

        MethodRemoveGetFieldPutFieldAdapter(int api, MethodVisitor methodVisitor) {
            super(api, methodVisitor);
        }

        @Override
        public void visitVarInsn(int opcode, int var) {
            // First: handle the states this callback is interested in.
            switch (state) {
                case SEEN_NOTHING:
                    if (opcode == ALOAD && var == 0) {
                        state = SEEN_ALOAD_0;
                        return;
                    }
                    break;
                case SEEN_ALOAD_0:
                    if (opcode == ALOAD && var == 0) {
                        state = SEEN_ALOAD_0_ALOAD_0;
                        return;
                    }
                    break;
                case SEEN_ALOAD_0_ALOAD_0:
                    if (opcode == ALOAD && var == 0) {
                        // Third ALOAD 0 in a row: emit the oldest one, keep two pending.
                        mv.visitVarInsn(opcode, var);
                        return;
                    }
                    break;
                default:
                    break;
            }

            // Second: everything else goes to the parent, which flushes pending instructions.
            super.visitVarInsn(opcode, var);
        }

        @Override
        public void visitFieldInsn(int opcode, String owner, String name, String descriptor) {
            // First: handle the states this callback is interested in.
            switch (state) {
                case SEEN_ALOAD_0_ALOAD_0:
                    if (opcode == GETFIELD) {
                        state = SEEN_ALOAD_0_ALOAD_0_GETFIELD;
                        fieldOwner = owner;
                        fieldName = name;
                        fieldDesc = descriptor;
                        return;
                    }
                    break;
                case SEEN_ALOAD_0_ALOAD_0_GETFIELD:
                    // The PUTFIELD must target the very same field. Comparing the
                    // name alone would also delete "this.val = super.val" when a
                    // subclass field hides a superclass field of the same name.
                    if (opcode == PUTFIELD
                            && owner.equals(fieldOwner)
                            && name.equals(fieldName)
                            && descriptor.equals(fieldDesc)) {
                        state = SEEN_NOTHING;
                        return;
                    }
                    break;
                default:
                    break;
            }

            // Second: everything else goes to the parent, which flushes pending instructions.
            super.visitFieldInsn(opcode, owner, name, descriptor);
        }

        /** Re-emits the instructions swallowed so far and returns to the initial state. */
        @Override
        protected void visitInsn() {
            switch (state) {
                case SEEN_ALOAD_0:
                    mv.visitVarInsn(ALOAD, 0);
                    break;
                case SEEN_ALOAD_0_ALOAD_0:
                    mv.visitVarInsn(ALOAD, 0);
                    mv.visitVarInsn(ALOAD, 0);
                    break;
                case SEEN_ALOAD_0_ALOAD_0_GETFIELD:
                    mv.visitVarInsn(ALOAD, 0);
                    mv.visitVarInsn(ALOAD, 0);
                    mv.visitFieldInsn(GETFIELD, fieldOwner, fieldName, fieldDesc);
                    break;
                default:
                    break;
            }
            state = SEEN_NOTHING;
        }
    }
}

进行转换

import lsieun.utils.FileUtils;
import org.objectweb.asm.*;

/**
 * Driver for the "field self-assignment" example: reads sample/HelloWorld.class,
 * runs it through {@code MethodRemoveGetFieldPutFieldVisitor} and writes the
 * result back to the same file.
 */
public class HelloWorldTransformCore {
    public static void main(String[] args) {
        // Locate and load the class file that is rewritten in place.
        final String classFile = FileUtils.getFilePath("sample/HelloWorld.class");
        final byte[] originalBytes = FileUtils.readBytes(classFile);

        // Reader -> transforming visitor -> writer.
        ClassReader reader = new ClassReader(originalBytes);
        ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
        ClassVisitor chain = new MethodRemoveGetFieldPutFieldVisitor(Opcodes.ASM9, writer);

        // Debug info and the original stack map frames are skipped while reading.
        reader.accept(chain, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);

        // Serialise the transformed class and overwrite the input file.
        FileUtils.writeBytes(classFile, writer.toByteArray());
    }
}

验证结果

$ javap -c sample.HelloWorld
public class sample.HelloWorld {
  public int val;

  public sample.HelloWorld();
    Code:
       0: aload_0
       1: invokespecial #10                 // Method java/lang/Object."<init>":()V
       4: return

  public void test(int, int);
    Code:
       0: iload_1
       1: iload_2
       2: iadd
       3: istore_3
       4: getstatic     #18                 // Field java/lang/System.out:Ljava/io/PrintStream;
       7: iload_3
       8: invokevirtual #24                 // Method java/io/PrintStream.println:(I)V
      11: return
}

示例三:删除打印语句

预期目标

假如有一个HelloWorld类,代码如下:

// Sample input class for the "remove print statement" example.
public class HelloWorld {
    public void test(int a, int b) {
        // Prints a string literal: the goal of this example is to delete this statement.
        System.out.println("Before a + b");
        int c = a + b;
        // Prints a string literal: the goal of this example is to delete this statement.
        System.out.println("After a + b");
        // Prints an int, not a string literal; this statement is still present in the transformed output.
        System.out.println(c);
    }
}

我们想要实现的预期目标:删除掉打印字符串的语句。

$ javap -c sample.HelloWorld
Compiled from "HelloWorld.java"
public class sample.HelloWorld {
...
  public void test(int, int);
    Code:
       0: getstatic     #2                  // Field java/lang/System.out:Ljava/io/PrintStream;
       3: ldc           #3                  // String Before a + b
       5: invokevirtual #4                  // Method java/io/PrintStream.println:(Ljava/lang/String;)V
       8: iload_1
       9: iload_2
      10: iadd
      11: istore_3
      12: getstatic     #2                  // Field java/lang/System.out:Ljava/io/PrintStream;
      15: ldc           #5                  // String After a + b
      17: invokevirtual #4                  // Method java/io/PrintStream.println:(Ljava/lang/String;)V
      20: getstatic     #2                  // Field java/lang/System.out:Ljava/io/PrintStream;
      23: iload_3
      24: invokevirtual #6                  // Method java/io/PrintStream.println:(I)V
      27: return
}

编码实现

import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;

import static org.objectweb.asm.Opcodes.*;

/**
 * Class visitor that wraps every concrete, non-constructor method in an adapter
 * which drops the sequence {@code GETSTATIC System.out, LDC <String>,
 * INVOKEVIRTUAL PrintStream.println(String)}.
 */
public class MethodRemovePrintVisitor extends ClassVisitor {
    public MethodRemovePrintVisitor(int api, ClassVisitor classVisitor) {
        super(api, classVisitor);
    }

    @Override
    public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
        MethodVisitor delegate = cv.visitMethod(access, name, descriptor, signature, exceptions);
        if (delegate == null) {
            return null;
        }
        // Constructors and static initialisers are passed through untouched.
        if ("<init>".equals(name) || "<clinit>".equals(name)) {
            return delegate;
        }
        // Abstract and native methods are passed through untouched as well.
        if ((access & ACC_ABSTRACT) != 0 || (access & ACC_NATIVE) != 0) {
            return delegate;
        }
        return new MethodRemovePrintAdaptor(api, delegate);
    }

    /**
     * State machine with the states {@code SEEN_NOTHING}, {@code SEEN_GETSTATIC}
     * (a {@code System.out} load is pending) and {@code SEEN_GETSTATIC_LDC}
     * (a {@code System.out} load and a String constant are pending).
     */
    private class MethodRemovePrintAdaptor extends MethodPatternAdapter {
        private static final int SEEN_GETSTATIC = 1;
        private static final int SEEN_GETSTATIC_LDC = 2;

        // String constant swallowed by visitLdcInsn; re-emitted if the pattern does not complete.
        private String message;

        public MethodRemovePrintAdaptor(int api, MethodVisitor methodVisitor) {
            super(api, methodVisitor);
        }

        @Override
        public void visitFieldInsn(int opcode, String owner, String name, String descriptor) {
            boolean loadsSystemOut = opcode == GETSTATIC
                    && owner.equals("java/lang/System")
                    && name.equals("out")
                    && descriptor.equals("Ljava/io/PrintStream;");

            if (loadsSystemOut) {
                if (state == SEEN_NOTHING) {
                    // Swallow the load and wait for what follows.
                    state = SEEN_GETSTATIC;
                    return;
                }
                if (state == SEEN_GETSTATIC) {
                    // Two loads in a row: emit the older one, keep the new one pending.
                    emitSystemOut();
                    return;
                }
            }

            // Anything else goes through the parent, which runs the flush hook first.
            super.visitFieldInsn(opcode, owner, name, descriptor);
        }

        @Override
        public void visitLdcInsn(Object value) {
            if (state == SEEN_GETSTATIC && value instanceof String) {
                state = SEEN_GETSTATIC_LDC;
                message = (String) value;
                return;
            }

            // Anything else goes through the parent, which runs the flush hook first.
            super.visitLdcInsn(value);
        }

        @Override
        public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) {
            if (state == SEEN_GETSTATIC_LDC) {
                boolean printsString = opcode == INVOKEVIRTUAL
                        && owner.equals("java/io/PrintStream")
                        && name.equals("println")
                        && descriptor.equals("(Ljava/lang/String;)V");
                if (printsString) {
                    // Pattern complete: none of the three instructions is emitted.
                    state = SEEN_NOTHING;
                    return;
                }
            }

            // Anything else goes through the parent, which runs the flush hook first.
            super.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
        }

        /** Re-emits the instructions swallowed so far and returns to the initial state. */
        @Override
        protected void visitInsn() {
            if (state == SEEN_GETSTATIC) {
                emitSystemOut();
            } else if (state == SEEN_GETSTATIC_LDC) {
                emitSystemOut();
                mv.visitLdcInsn(message);
            }

            state = SEEN_NOTHING;
        }

        /** Sends a {@code GETSTATIC java/lang/System.out} directly to the next visitor. */
        private void emitSystemOut() {
            mv.visitFieldInsn(GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
        }
    }
}

进行转换

import lsieun.utils.FileUtils;
import org.objectweb.asm.*;

/**
 * Driver for the "remove print statement" example: reads sample/HelloWorld.class,
 * runs it through {@code MethodRemovePrintVisitor} and writes the result back to
 * the same file.
 */
public class HelloWorldTransformCore {
    public static void main(String[] args) {
        // Locate and load the class file that is rewritten in place.
        final String classFile = FileUtils.getFilePath("sample/HelloWorld.class");
        final byte[] originalBytes = FileUtils.readBytes(classFile);

        // Reader -> transforming visitor -> writer.
        ClassReader reader = new ClassReader(originalBytes);
        ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
        ClassVisitor chain = new MethodRemovePrintVisitor(Opcodes.ASM9, writer);

        // Debug info and the original stack map frames are skipped while reading.
        reader.accept(chain, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);

        // Serialise the transformed class and overwrite the input file.
        FileUtils.writeBytes(classFile, writer.toByteArray());
    }
}

验证结果

$ javap -c sample.HelloWorld
public class sample.HelloWorld {
...
  public void test(int, int);
    Code:
       0: iload_1
       1: iload_2
       2: iadd
       3: istore_3
       4: getstatic     #16                 // Field java/lang/System.out:Ljava/io/PrintStream;
       7: iload_3
       8: invokevirtual #22                 // Method java/io/PrintStream.println:(I)V
      11: return
}

小结

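本文对stateful transformations进行介绍,内容总结如下:

- 第一点,stateless transformation不依赖于之前访问过的指令,实现起来比较容易;stateful transformation需要记住之前访问过的指令的状态,实现起来比较困难。
- 第二点,stateful transformation可以借助于state machine来实现:MethodPatternAdapter类是一个原始的state machine,我们通过编写它的子类来识别具体的指令组合。
- 第三点,本文通过三个示例(删除“加零”、删除字段的自我赋值、删除打印字符串的语句)展示了具体的实现方式。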

上一篇 下一篇

猜你喜欢

热点阅读