36 - ASM之方法计时

2022-02-16  本文已影响0人  舍是境界

预期目标

假如有一个HelloWorld类,代码如下:

import java.util.Random;

public class HelloWorld {

    public int add(int a, int b) throws InterruptedException {
        int c = a + b;
        Random rand = new Random(System.currentTimeMillis());
        int num = rand.nextInt(300);
        Thread.sleep(100 + num);
        return c;
    }

    public int sub(int a, int b) throws InterruptedException {
        int c = a - b;
        Random rand = new Random(System.currentTimeMillis());
        int num = rand.nextInt(400);
        Thread.sleep(100 + num);
        return c;
    }
}

我们想实现的预期目标:计算出方法的运行时间。这里有两种实现方式:

第一种方式,将所有方法的运行时间累加在一起,统一记录在同一个timer字段当中:

import java.util.Random;

public class HelloWorld {
    public static long timer;

    public int add(int a, int b) throws InterruptedException {
        timer -= System.currentTimeMillis();
        int c = a + b;
        Random rand = new Random(System.currentTimeMillis());
        int num = rand.nextInt(300);
        Thread.sleep(100 + num);
        timer += System.currentTimeMillis();
        return c;
    }

    public int sub(int a, int b) throws InterruptedException {
        timer -= System.currentTimeMillis();
        int c = a - b;
        Random rand = new Random(System.currentTimeMillis());
        int num = rand.nextInt(400);
        Thread.sleep(100 + num);
        timer += System.currentTimeMillis();
        return c;
    }
}

第二种方式,计算每个方法的运行时间,将每个方法的运行时间单独记录在对应的字段当中:

import java.util.Random;

public class HelloWorld {
    public static long timer_add;
    public static long timer_sub;

    public int add(int a, int b) throws InterruptedException {
        timer_add -= System.currentTimeMillis();
        int c = a + b;
        Random rand = new Random(System.currentTimeMillis());
        int num = rand.nextInt(300);
        Thread.sleep(100 + num);
        timer_add += System.currentTimeMillis();
        return c;
    }

    public int sub(int a, int b) throws InterruptedException {
        timer_sub -= System.currentTimeMillis();
        int c = a - b;
        Random rand = new Random(System.currentTimeMillis());
        int num = rand.nextInt(400);
        Thread.sleep(100 + num);
        timer_sub += System.currentTimeMillis();
        return c;
    }
}

实现这个功能的思路:在“方法进入”的时候,让字段减去当前时间戳;在“方法退出”的时候,再让字段加上当前时间戳。这样一进一出,字段中累加的正好是方法本次运行的时间差(退出时间 - 进入时间)。

有一个问题,我们为什么要计算方法的运行时间呢?如果我们想对现有的程序进行优化,那么需要对程序的整体性能有所了解,而方法的运行时间是衡量程序性能的一个重要参考。

第一种实现方式

import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.MethodVisitor;

import static org.objectweb.asm.Opcodes.*;

/**
 * Class visitor that makes every concrete method of a class accumulate its
 * running time, in milliseconds, into one shared {@code public static long timer}
 * field which this visitor adds to the class.
 *
 * <p>Interfaces are left untouched. Constructors ({@code <init>}), static
 * initializers ({@code <clinit>}), abstract methods and native methods are
 * not instrumented.
 */
public class MethodTimerVisitor extends ClassVisitor {
    /** Name of the generated static field that accumulates elapsed time. */
    private static final String TIMER_FIELD = "timer";
    /** Field descriptor of the generated field: primitive {@code long}. */
    private static final String TIMER_DESCRIPTOR = "J";

    /** Internal name of the class being visited; set in {@link #visit}. */
    private String owner;
    /** Whether the class being visited is an interface; set in {@link #visit}. */
    private boolean isInterface;

    /**
     * @param api          the ASM API version implemented by this visitor
     * @param classVisitor the visitor that receives the transformed class
     */
    public MethodTimerVisitor(int api, ClassVisitor classVisitor) {
        super(api, classVisitor);
    }

    /** Records the class name and whether it is an interface, then delegates. */
    @Override
    public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
        super.visit(version, access, name, signature, superName, interfaces);
        owner = name;
        isInterface = (access & ACC_INTERFACE) != 0;
    }

    /**
     * Wraps the downstream method visitor in a timing adapter for every method
     * that has a body and is neither a constructor nor a static initializer.
     */
    @Override
    public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
        MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
        if (!isInterface && mv != null && !"<init>".equals(name) && !"<clinit>".equals(name)) {
            boolean isAbstractMethod = (access & ACC_ABSTRACT) != 0;
            boolean isNativeMethod = (access & ACC_NATIVE) != 0;
            // Abstract and native methods have no Code attribute to instrument.
            if (!isAbstractMethod && !isNativeMethod) {
                mv = new MethodTimerAdapter(api, mv, owner);
            }
        }
        return mv;
    }

    /** Adds the {@code timer} field to non-interface classes, then ends the visit. */
    @Override
    public void visitEnd() {
        if (!isInterface) {
            FieldVisitor fv = super.visitField(ACC_PUBLIC | ACC_STATIC, TIMER_FIELD, TIMER_DESCRIPTOR, null, null);
            if (fv != null) {
                fv.visitEnd();
            }
        }
        super.visitEnd();
    }

    /**
     * Method adapter that emits {@code timer -= System.currentTimeMillis()} at
     * method entry and {@code timer += System.currentTimeMillis()} before every
     * return instruction and every explicit {@code ATHROW}.
     *
     * <p>Only exits that appear as instructions in this method are covered: an
     * exception that propagates out of a called method does not pass through
     * any instrumented instruction, so the exit timestamp is not added then.
     */
    private static class MethodTimerAdapter extends MethodVisitor {
        private final String owner;

        public MethodTimerAdapter(int api, MethodVisitor mv, String owner) {
            super(api, mv);
            this.owner = owner;
        }

        @Override
        public void visitCode() {
            // The MethodVisitor contract requires visitCode() to precede any
            // instruction, so start the code first and only then emit ours.
            super.visitCode();
            emitTimerUpdate(LSUB);
        }

        @Override
        public void visitInsn(int opcode) {
            // IRETURN..RETURN are contiguous opcodes covering all return kinds.
            if ((opcode >= IRETURN && opcode <= RETURN) || opcode == ATHROW) {
                emitTimerUpdate(LADD);
            }
            super.visitInsn(opcode);
        }

        @Override
        public void visitMaxs(int maxStack, int maxLocals) {
            // The inserted sequence pushes two long values (two slots each) on
            // top of whatever is already on the operand stack. The values are
            // ignored when the ClassWriter recomputes maxs or frames.
            super.visitMaxs(maxStack + 4, maxLocals);
        }

        /**
         * Emits {@code timer = timer <op> System.currentTimeMillis()}.
         *
         * @param arithmeticOpcode {@code LSUB} at method entry, {@code LADD} at exit
         */
        private void emitTimerUpdate(int arithmeticOpcode) {
            super.visitFieldInsn(GETSTATIC, owner, TIMER_FIELD, TIMER_DESCRIPTOR);
            super.visitMethodInsn(INVOKESTATIC, "java/lang/System", "currentTimeMillis", "()J", false);
            super.visitInsn(arithmeticOpcode);
            super.visitFieldInsn(PUTSTATIC, owner, TIMER_FIELD, TIMER_DESCRIPTOR);
        }
    }
}
import lsieun.utils.FileUtils;
import org.objectweb.asm.*;

/**
 * Driver that reads {@code sample/HelloWorld.class}, runs it through
 * {@link MethodTimerVisitor}, and writes the result back to the same path.
 */
public class HelloWorldTransformCore {
    public static void main(String[] args) {
        String classFilePath = FileUtils.getFilePath("sample/HelloWorld.class");
        byte[] originalBytes = FileUtils.readBytes(classFilePath);

        // Reader -> timing visitor -> writer chain.
        ClassReader reader = new ClassReader(originalBytes);
        ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
        ClassVisitor timingVisitor = new MethodTimerVisitor(Opcodes.ASM9, writer);

        // Existing frames are skipped because the writer recomputes them;
        // debug information is skipped as well.
        reader.accept(timingVisitor, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);

        // The transformed class overwrites the file it was read from.
        FileUtils.writeBytes(classFilePath, writer.toByteArray());
    }
}
import java.lang.reflect.Field;
import java.util.Random;

/**
 * Demo driver: exercises {@code HelloWorld} ten times with random operands,
 * then prints every declared field of {@code HelloWorld} whose name starts
 * with {@code "timer"}.
 */
public class HelloWorldRun {
    public static void main(String[] args) throws Exception {
        // Part one: let the program run for a while so the timers accumulate.
        HelloWorld target = new HelloWorld();
        Random random = new Random(System.currentTimeMillis());
        for (int round = 0; round < 10; round++) {
            boolean useAdd = random.nextBoolean();
            int a = random.nextInt(50);
            int b = random.nextInt(50);
            int result = useAdd ? target.add(a, b) : target.sub(a, b);
            String template = useAdd ? "%d + %d = %d" : "%d - %d = %d";
            System.out.println(String.format(template, a, b, result));
        }

        // Part two: read the timer fields reflectively and print them.
        for (Field field : HelloWorld.class.getDeclaredFields()) {
            String name = field.getName();
            field.setAccessible(true);
            if (!name.startsWith("timer")) {
                continue;
            }
            // A null receiver is passed, so the field is read as a static field.
            Object value = field.get(null);
            System.out.println(name + " = " + value);
        }
    }
}

第二种实现方式

import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.MethodVisitor;

import static org.objectweb.asm.Opcodes.*;

/**
 * Class visitor that makes every concrete method accumulate its running time,
 * in milliseconds, into a dedicated {@code public static long timer_<methodName>}
 * field which this visitor adds to the class.
 *
 * <p>Overloaded methods share one field, because the field name is derived
 * from the method name only. Interfaces are left untouched. Constructors
 * ({@code <init>}), static initializers ({@code <clinit>}), abstract methods
 * and native methods are not instrumented.
 */
public class MethodTimerVisitor2 extends ClassVisitor {
    /** Prefix of every generated timer field name. */
    private static final String TIMER_PREFIX = "timer_";
    /** Field descriptor of the generated fields: primitive {@code long}. */
    private static final String TIMER_DESCRIPTOR = "J";

    /** Timer field names already added to the current class, to avoid duplicates. */
    private final java.util.Set<String> declaredTimerFields = new java.util.HashSet<>();

    /** Internal name of the class being visited; set in {@link #visit}. */
    private String owner;
    /** Whether the class being visited is an interface; set in {@link #visit}. */
    private boolean isInterface;

    /**
     * @param api          the ASM API version implemented by this visitor
     * @param classVisitor the visitor that receives the transformed class
     */
    public MethodTimerVisitor2(int api, ClassVisitor classVisitor) {
        super(api, classVisitor);
    }

    /** Records the class name and whether it is an interface, then delegates. */
    @Override
    public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
        super.visit(version, access, name, signature, superName, interfaces);
        owner = name;
        isInterface = (access & ACC_INTERFACE) != 0;
        // Start from a clean slate in case this visitor instance is reused.
        declaredTimerFields.clear();
    }

    /**
     * For every method that has a body and is neither a constructor nor a
     * static initializer, adds the matching timer field (once per method name)
     * and wraps the downstream method visitor in a timing adapter.
     */
    @Override
    public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
        MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
        if (isInterface || mv == null || "<init>".equals(name) || "<clinit>".equals(name)) {
            return mv;
        }
        // Abstract and native methods have no Code attribute to instrument.
        if ((access & (ACC_ABSTRACT | ACC_NATIVE)) != 0) {
            return mv;
        }

        // Overloads map to the same field name; declaring it twice would
        // produce a class file with a duplicate field, so add it only once.
        String fieldName = getFieldName(name);
        if (declaredTimerFields.add(fieldName)) {
            FieldVisitor fv = super.visitField(ACC_PUBLIC | ACC_STATIC, fieldName, TIMER_DESCRIPTOR, null, null);
            if (fv != null) {
                fv.visitEnd();
            }
        }
        return new MethodTimerAdapter2(api, mv, owner, name);
    }

    /** Returns the name of the timer field used for the given method name. */
    private static String getFieldName(String methodName) {
        return TIMER_PREFIX + methodName;
    }

    /**
     * Method adapter that emits {@code timer_x -= System.currentTimeMillis()}
     * at method entry and {@code timer_x += System.currentTimeMillis()} before
     * every return instruction and every explicit {@code ATHROW}.
     *
     * <p>Only exits that appear as instructions in this method are covered: an
     * exception that propagates out of a called method does not pass through
     * any instrumented instruction, so the exit timestamp is not added then.
     */
    private static class MethodTimerAdapter2 extends MethodVisitor {
        private final String owner;
        /** Timer field for this method; must match the field added by the class visitor. */
        private final String fieldName;

        public MethodTimerAdapter2(int api, MethodVisitor mv, String owner, String methodName) {
            super(api, mv);
            this.owner = owner;
            this.fieldName = getFieldName(methodName);
        }

        @Override
        public void visitCode() {
            // The MethodVisitor contract requires visitCode() to precede any
            // instruction, so start the code first and only then emit ours.
            super.visitCode();
            emitTimerUpdate(LSUB);
        }

        @Override
        public void visitInsn(int opcode) {
            // IRETURN..RETURN are contiguous opcodes covering all return kinds.
            if ((opcode >= IRETURN && opcode <= RETURN) || opcode == ATHROW) {
                emitTimerUpdate(LADD);
            }
            super.visitInsn(opcode);
        }

        @Override
        public void visitMaxs(int maxStack, int maxLocals) {
            // The inserted sequence pushes two long values (two slots each) on
            // top of whatever is already on the operand stack. The values are
            // ignored when the ClassWriter recomputes maxs or frames.
            super.visitMaxs(maxStack + 4, maxLocals);
        }

        /**
         * Emits {@code field = field <op> System.currentTimeMillis()}.
         *
         * @param arithmeticOpcode {@code LSUB} at method entry, {@code LADD} at exit
         */
        private void emitTimerUpdate(int arithmeticOpcode) {
            super.visitFieldInsn(GETSTATIC, owner, fieldName, TIMER_DESCRIPTOR);
            super.visitMethodInsn(INVOKESTATIC, "java/lang/System", "currentTimeMillis", "()J", false);
            super.visitInsn(arithmeticOpcode);
            super.visitFieldInsn(PUTSTATIC, owner, fieldName, TIMER_DESCRIPTOR);
        }
    }
}
运行HelloWorldRun之后,第二种实现方式的输出结果如下(每次运行的具体数值会有所不同):

7 + 30 = 37
19 - 26 = -7
27 + 36 = 63
8 + 5 = 13
42 - 10 = 32
27 + 17 = 44
16 - 40 = -24
44 + 23 = 67
14 + 29 = 43
20 - 27 = -7
timer_add = 1596
timer_sub = 974

总结

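本文主要介绍了如何计算方法的运行时间,内容总结如下:第一,实现思路是在“方法进入”时减去一个时间戳,在“方法退出”时加上一个时间戳,从而记录下时间差;第二,具体实现有两种方式,既可以让所有方法共用一个timer字段,也可以为每个方法单独添加一个以timer_开头的字段。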

其实,遵循同样的思路,我们也可以统计方法运行的总次数;再进一步,用总运行时间除以总次数,就可以计算出方法每次调用的平均执行时间。

上一篇下一篇

猜你喜欢

热点阅读