Developing with TensorFlow Lite in Android Studio

2018-12-04  四月是你的谎言_6b55

Prepare the tflite model

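Create an assets directory under the source directory (in a default Android Studio project this is app/src/main/assets) and copy model.tflite and labels.txt into it. Note that the directory must be named assets, not asserts.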
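The labels file is needed later as a list of strings, one per output class. The following is a minimal sketch of reading it, assuming labels.txt is UTF-8 text with one label per line. `LabelLoader` and `readLabels` are hypothetical names introduced for illustration; they are not part of the original post or of TensorFlow Lite.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for illustration only.
final class LabelLoader {
    private LabelLoader() {}

    /** Reads one label per line, trimming whitespace and skipping blank lines. */
    static List<String> readLabels(InputStream in) throws IOException {
        List<String> labels = new ArrayList<>();
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String label = line.trim();
                if (!label.isEmpty()) {
                    labels.add(label);
                }
            }
        }
        return labels;
    }
}
```

On Android the stream would typically come from `assetManager.open("labels.txt")`.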

Configure build.gradle

To use TensorFlow Lite, the corresponding library has to be imported. This is done by editing build.gradle:

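Add 'org.tensorflow:tensorflow-lite:+' under dependencies. The + picks the latest available version; pin a specific version if you need reproducible builds.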

dependencies {
    implementation fileTree(dir: 'libs', include: ['*.jar'])
    implementation 'com.android.support:appcompat-v7:28.0.0'
    implementation 'com.android.support.constraint:constraint-layout:1.1.3'
    testImplementation 'junit:junit:4.12'
    androidTestImplementation 'com.android.support.test:runner:1.0.2'
    androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
    implementation 'org.tensorflow:tensorflow-lite:+'    
}

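Add aaptOptions under android so that the .tflite file is stored uncompressed in the APK. This is required because the model is later memory-mapped through openFd, which does not work on compressed assets.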

android {
    compileSdkVersion 28
    defaultConfig {
        applicationId "com.example.test.voicerecognition"
        minSdkVersion 26
        targetSdkVersion 28
        versionCode 1
        versionName "1.0"
        testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
    }
    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
        }
    }
    aaptOptions {
        noCompress "tflite"
    }
}

Then re-sync Gradle and the library is ready to use.

Using TensorFlow Lite in Java code

1 Import the library

import org.tensorflow.lite.Interpreter;

2 Instantiate the Interpreter object, which is then used to process the data, feed it to the model, and obtain the results

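private Interpreter tfLite;
...
// c is presumably the classifier instance that owns the tfLite field; its creation is elided above.
try {
    c.tfLite = new Interpreter(loadModelFile(assetManager, modelFilename));
} catch (Exception e) {
    throw new RuntimeException(e);
}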

3 Load the model

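    /**
     * Memory-map the model file in Assets.
     * Needs android.content.res.AssetManager, android.content.res.AssetFileDescriptor,
     * java.io.FileInputStream, java.io.IOException, java.nio.MappedByteBuffer
     * and java.nio.channels.FileChannel.
     */
    private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
            throws IOException {
        // try-with-resources closes the descriptor and stream; the mapping itself stays valid.
        try (AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename);
             FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
            FileChannel fileChannel = inputStream.getChannel();
            // The asset sits inside the APK, so map only its own byte range.
            long startOffset = fileDescriptor.getStartOffset();
            long declaredLength = fileDescriptor.getDeclaredLength();
            return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
        }
    }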
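The offset and length matter because an asset is not a standalone file on the device: it is a byte range inside the APK. The same idea can be tried in plain desktop Java with the sketch below, which maps only a sub-range of an ordinary file. `RegionMapper` and `mapRegion` are hypothetical names used for illustration; on Android the offset and length come from AssetFileDescriptor as in loadModelFile.

```java
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

// Hypothetical desktop-Java illustration of mapping a byte range of a file.
final class RegionMapper {
    private RegionMapper() {}

    /** Maps bytes [offset, offset + length) of the file as a read-only buffer. */
    static MappedByteBuffer mapRegion(Path file, long offset, long length) throws IOException {
        try (FileChannel channel = FileChannel.open(file, StandardOpenOption.READ)) {
            // A mapping, once established, does not depend on the channel staying open.
            return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
        }
    }
}
```

The returned buffer starts at position 0 even though it points at `offset` in the file, so code consuming it does not need to know where the range begins.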

4 Prepare the data, run the model, and read the model's predictions
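The post does not show how imgData is filled. As a rough sketch, assuming an image model with float input, the pixels can be packed into a direct ByteBuffer in native byte order, which Interpreter.run accepts as input. `InputPacker`, `packPixels`, and the normalization constants `MEAN` and `STD` are assumptions made for illustration; the correct preprocessing depends on how the model was trained.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical helper for illustration only.
final class InputPacker {
    // Assumed normalization constants; use the values your model was trained with.
    static final float MEAN = 128.0f;
    static final float STD = 128.0f;

    private InputPacker() {}

    /** Packs ARGB pixels into a direct float buffer, three floats (R, G, B) per pixel. */
    static ByteBuffer packPixels(int[] argbPixels) {
        // 4 bytes per float, 3 channels per pixel.
        ByteBuffer imgData = ByteBuffer.allocateDirect(4 * 3 * argbPixels.length);
        imgData.order(ByteOrder.nativeOrder());
        for (int pixel : argbPixels) {
            imgData.putFloat((((pixel >> 16) & 0xFF) - MEAN) / STD); // red
            imgData.putFloat((((pixel >> 8) & 0xFF) - MEAN) / STD);  // green
            imgData.putFloat(((pixel & 0xFF) - MEAN) / STD);         // blue
        }
        // Reset the position so the consumer reads from the start.
        imgData.rewind();
        return imgData;
    }
}
```

With such a buffer in hand, the call that runs the model looks as follows: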

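// imgData holds the preprocessed input; labelProb receives the output scores, indexed as labelProb[0][i].
tfLite.run(imgData, labelProb);
// Wrap each score in a Recognition and add it to the queue pq.
for (int i = 0; i < labels.size(); ++i) {
    pq.add(
            new Recognition(
                    "" + i,
                    labels.size() > i ? labels.get(i) : "unknown",
                    (float) labelProb[0][i],
                    null));
}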
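The declaration of pq is not shown above. Assuming it is a PriorityQueue ordered by descending confidence, the best few results can then be polled from it. The sketch below shows that idea in plain Java; `TopK` and `Result` are hypothetical stand-ins (the latter for the Recognition class) and are not taken from the original code.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical illustration of picking the k highest-scoring labels.
final class TopK {
    /** Minimal stand-in for the Recognition class used above. */
    static final class Result {
        final String label;
        final float confidence;

        Result(String label, float confidence) {
            this.label = label;
            this.confidence = confidence;
        }
    }

    private TopK() {}

    /** Returns up to k results, highest confidence first. */
    static List<Result> topK(float[] probs, List<String> labels, int k) {
        // The comparator is reversed so that poll() returns the highest confidence.
        PriorityQueue<Result> pq = new PriorityQueue<>(
                Math.max(1, probs.length),
                (a, b) -> Float.compare(b.confidence, a.confidence));
        for (int i = 0; i < probs.length; ++i) {
            pq.add(new Result(i < labels.size() ? labels.get(i) : "unknown", probs[i]));
        }
        List<Result> best = new ArrayList<>();
        int n = Math.min(k, pq.size());
        for (int i = 0; i < n; ++i) {
            best.add(pq.poll());
        }
        return best;
    }
}
```

Calling `TopK.topK(labelProb[0], labels, 3)` would, under these assumptions, give the three most likely labels.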

References

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/java/demo
