一个案例搞定策略模式

2020-08-29  本文已影响0人  HurryYu_YZH

提到设计模式,只要是有过开发经验的开发人员都或多或少听过&用过设计模式,比如我们都能信手拈来的「单例模式」、「观察者模式」等等。当然也有我们平时不太常用,但众多优秀的开源框架中广泛使用的设计模式,例如著名的网络框架retrofit使用的「代理模式」、okhttp使用的「责任链模式」。

关于设计模式的文章,网上一搜一大堆,各位前辈都总结得非常好。可以说前人的技术分享大大降低了后人的学习门槛,使中国互联网整体技术水平成指数上升,感谢每一位热爱分享的Coder!

很早以前我就准备写一篇介绍策略模式的文章,但是始终没有一个较好的例子。在最近的项目中,我再次用到了策略模式,于是我决定将其作为本文讲解策略模式的案例。

本文会先直接通过实际案例的形式逐步带入策略模式,最后再给出策略模式的完整定义,这样更容易理解。

为了不偏离主题,提升阅读体验,本文所有代码都经过精简处理。

一、案例前戏

公司参加了某人工智能比赛,AI部门的同事使用TensorFlow训练了一个能根据呼吸音推测患上“肺气肿”概率的模型。需要在Android设备上通过APP + 听诊器完成呼吸音的采集,然后通过模型给出结论。

由于时间紧迫,不知同事从哪里搞来了一个半成品项目,该项目已经实现了呼吸音的采集功能,将音频保存为.wav文件。而我需要做的,就是将采集到的呼吸音交给模型,得出结论。

而TensorFlow模型在Android上是没办法直接使用的,必须要将TensorFlow模型转换为TensorFlow Lite才可以在Android上使用(反正又不是我转o( ̄︶ ̄)o)。

二、案例中期

趁着AI部门的同事还没有将TensorFlow转为TensorFlow Lite之前,我去TensorFlow Lite官网看了看使用文档,一切都是如此美妙,仅需3步就可以搞定一切:

  1. 加载并初始化模型文件
  2. 调用run方法传入inputObject和outputObject
  3. 得到结果

在这个项目中,inputObject其实就是.wav文件的字节数组,由于模型的返回结果是JSON,因此outputObject其实就是String。现在我就放心的摸鱼去了。

摸鱼的时光总是短暂的,不一会儿,同事就给了我TensorFlow Lite的模型文件,后缀是.tflite。按照官方文档的要求,我将模型文件audio.tflite放到了assets目录中,然后编写了如下代码:

private void stopRecord() {
    // 省略其它代码,这里已经获取到了wav录音的File对象
    File wavFile = ...;
    aiCheckLocal(wavFile);
}

private void aiCheckLocal(File file) {
    ByteBuffer buffer = loadModelFile(getAssets());
    // 初始化模型
    Interpreter tfLite = new Interpreter(buffer);
    byte[] inputBytes = FileUtils.fileToByteArray(file);
    String outputStr = "";
    // 调用模型
    tfLite.run(inputBytes, outputStr);
    // 输出模型的结论
    Log.d(TAG, outputStr);
}

private MappedByteBuffer loadModelFile(AssetManager assetManager) {
    // 读取assets目录下tflite模型的代码,无需关心
}

aiCheckLocal方法就是我编写的根据手机采集到的声音数据,利用TensorFlow Lite进行结果推测。大家在看代码的时候,不需要关注具体的细节。

呵呵,不到10分钟就搞定了。一运行,???:

Internal error: Unexpected failure when preparing tensor allocations: Encountered unresolved custom op: Switch.
    Node number 10 (Switch) failed to prepare.

最终查明这个错是因为AI部门同事给我的模型有问题,加载不了。短时间内他们也没办法解决,于是他们提出了让我调用接口完成音频数据的推测,也就是将这个过程变为了在线而不是本地模型推测了。同时,还需要我保留本地模型推测的代码,以便后期他们修复模型问题后,还是可以切换为本地模型推测。

这有什么难的,再加个在线推测的方法不就行了:

private void stopRecord() {
    // 省略其它代码,这里已经获取到了wav录音的File对象
    File wavFile = ...;
    aiCheckOnline(wavFile);
}

private void aiCheckOnline(File file) {
    // 通过网络推测代码
    OkHttpClient client = new OkHttpClient();
    MediaType mediaType = MediaType.Companion.parse("multipart/form-data");
    RequestBody fileBody = RequestBody.Companion.create(file, mediaType);

    RequestBody requestBody = new MultipartBody.Builder()
        .setType(MultipartBody.FORM)
        .addFormDataPart("sound", file.getName(), fileBody)
        .build();

    Request request = new Request.Builder()
        .url(REQUEST_RUL)
        .post(requestBody)
        .build();
    
    client.newCall(request).enqueue(new Callback() {
        @Override
        public void onFailure(@NotNull Call call, @NotNull IOException e) {
        }

        @Override
        public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
            String jsonStr = response.body().string();
            Log.d(TAG, jsonStr);
        }
    });
}

private void aiCheckLocal(File file) {
    // 本地tflite推测代码 省略
}

我又新增了一个aiCheckOnline的方法用于在线推测。这样,当使用在线推测的时候,调用aiCheckOnline,使用本地模型推测的时候,调用aiCheckLocal即可。但是我们回过头来想想,功能虽然是实现了,但这样真的好吗?现在的代码有如下问题:

也许你觉得这些都不是问题,因为你已经有解决方案了:对于第一个问题,我们把这些方法写在一个单独的类中不就行了;对于第二个问题,我们在类中对外界提供一个统一个调用方法,在这个方法内部进行判断到底是需要何种实现方案;对于第三个问题,由于我们提供了统一的调用方法,因此这个问题也就不存在了。所以,我们的代码可以改成这样:

public class AiCheck {
    /**
     * 执行推测时的方式
     */
    private int type;

    /**
     * 方式1 本地模型推测
     */
    public static final int TYPE_LOCAL = 1;

    /**
     * 方式二 网络接口推测
     */
    public static final int TYPE_NETWORK = 2;

    public AiCheck(int type) {
        this.type = type;
    }

    public void check(File file) {
        // 根据方式,执行对应的方法
        if (type == TYPE_LOCAL) {
            aiCheckLocal(file);
        } else if (type == TYPE_NETWORK) {
            aiCheckOnline(file);
        }
    }

    private void aiCheckOnline(File file) {
        // 通过网络推测代码
        OkHttpClient client = new OkHttpClient();
        MediaType mediaType = MediaType.Companion.parse("multipart/form-data");
        RequestBody fileBody = RequestBody.Companion.create(file, mediaType);

        RequestBody requestBody = new MultipartBody.Builder()
                .setType(MultipartBody.FORM)
                .addFormDataPart("sound", file.getName(), fileBody)
                .build();

        Request request = new Request.Builder()
                .url(REQUEST_RUL)
                .post(requestBody)
                .build();

        client.newCall(request).enqueue(new Callback() {
            @Override
            public void onFailure(@NotNull Call call, @NotNull IOException e) {
            }

            @Override
            public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
                String jsonStr = response.body().string();
                Log.d(TAG, jsonStr);
            }
        });
    }

    private void aiCheckLocal(File file) {
        ByteBuffer buffer = loadModelFile(getAssets());
        // 初始化模型
        Interpreter tfLite = new Interpreter(buffer);
        byte[] inputBytes = FileUtils.fileToByteArray(file);
        String outputStr = "";
        // 调用模型
        tfLite.run(inputBytes, outputStr);
        // 输出模型的结论
        Log.d(TAG, outputStr);
    }

    private MappedByteBuffer loadModelFile(AssetManager assetManager) {
        // 读取assets目录下tflite模型的代码,无需关心
    }
}

这样,我们的调用就可以变为这样:

private AiCheck aiCheck = new AiCheck(AiCheck.TYPE_NETWORK);

private void stopRecord() {
    // 省略其它代码,这里已经获取到了wav录音的File对象
    File wavFile = ...;
    aiCheck.check(wavFile);
}

调用就变得如此清爽了。但是这样真的就没有任何问题了吗?其实还是有的:

这...怎么办?

三、案例高潮

到了这个时间点,必须要放出大招了——策略模式。我们知道设计模式分为了三类:创建型、行为型、结构型,而策略模式属于行为型。先不谈其定义,我们来看看如何使用策略模式改进当前的问题。

首先,对于我们要实现的这个功能,行为只有一个,那就是音频推测,因此我们可以将这个行为抽象成一个接口:

/**
 * 行为接口
 */
public interface IAudioCheckBehavior {
    void check(File wavFile);
}

接着,在当前的情况下,针对这个行为,我们需要两种实现方案:本地模型推测 和 在线推测。我们分别写两个类来实现这两种方案:

这两种方案都实现了IAudioCheckBehavior接口,并各自用自己的方式实现了接口中的check方法。如果后期要增加新的实现方案呢?我们可以再定义一个类,并实现IAudioCheckBehavior接口就行了,不会在原有类中去添加或修改代码。

最后,我们还需要定义一个Context类,这个类的作用是设置一个具体的策略供外界使用:

public class AudioCheckContext {
    private IAudioCheckBehavior behavior;

    public AudioCheckContext() {
        this.behavior = new OnlineAudioCheck();
    }

    public void setAudioCheckBehavior(IAudioCheckBehavior behavior) {
        this.behavior = behavior;
    }

    public void check(File wavFile) {
        behavior.check(wavFile);
    }
}

注意,这个behaviorIAudioCheckBehavior类型的,可赋值为它的任意子类,比如OnlineAudioCheck或是LocalAudioCheck,这其实就是多态的思想。如果外界没有明确指定,则behavior的默认实现是OnlineAudioCheck,当然外界也可手动指定behavior的实现类。

没错,这就是策略模式在本案例中的完整实现了,下面用起来就很爽了:

private AudioCheckContext audioCheckContext = new AudioCheckContext();

private void stopRecord() {
    // 省略其它代码,这里已经获取到了wav录音的File对象
    File wavFile = ...;
    audioCheckContext.check(wavFile);
}

首先创建了AudioCheckContext对象,然后调用了AudioCheckContextcheck方法,在check方法内部,会去调用IAudioCheckBehaviorcheck方法,由于我们没有在外界设置IAudioCheckBehavior,因此它的实现类默认是OnlineAudioCheck,转而就会去走在线推测音频的逻辑。

真爽,如果哪天AI部门的同事把模型搞定了,要求我改为本地模型推测,我只需新增或修改一行代码:

如果我发现本地推测的实现代码有问题,直接去对应的LocalAudioCheck类修复就好了,不会影响到其它功能。

四、策略模式的定义

好了,现在是时候来看看策略模式的定义了:

定义了算法族,分别封装起来,让它们之间可以互相替换,此模式让算法的变化独立于使用算法的客户

有了上面的案例,理解这个定义就容易多了。

总结一下策略模式的核心思想:将变化的部分独立出来,将它们单独实现成算法类,并且这些算法是可以相互替换且对调用者隐藏实现细节。

上一篇 下一篇

猜你喜欢

热点阅读