您的位置:首页 > 娱乐 > 八卦 > 深圳龙华区天气预报_东莞专业网站推广方式_全国最新的疫情数据_南宁seo网络推广

深圳龙华区天气预报_东莞专业网站推广方式_全国最新的疫情数据_南宁seo网络推广

2024/10/6 13:11:24 来源:https://blog.csdn.net/model2005/article/details/142717334  浏览:    关键词:深圳龙华区天气预报_东莞专业网站推广方式_全国最新的疫情数据_南宁seo网络推广
深圳龙华区天气预报_东莞专业网站推广方式_全国最新的疫情数据_南宁seo网络推广

APP开发

build.gradle导入库

    //implementation 'org.tensorflow:tensorflow-android:+'
    implementation 'org.tensorflow:tensorflow-lite:2.4.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.3.1
    implementation 'org.tensorflow:tensorflow-lite-metadata:0.3.1'

加载模型

      try {
            tfLiteClassificationUtil = new TFLiteClassificationUtil(CONST.downPath + "/zjym.tflite");
            Toast.makeText(MainActTflite.this, "模型加载成功!", Toast.LENGTH_SHORT).show();
        } catch (Exception e) {
            Toast.makeText(MainActTflite.this, "模型加载失败!", Toast.LENGTH_SHORT).show();
            e.printStackTrace();
            finish();
        }

模型一般在assets目录下,在编译时会集成到APP中,不利于模型的迭代,这里模型保存在内部存储目录下。

分类预测

                try {      // 预测图像
                    FileInputStream fis = new FileInputStream(image_path);
                    imageView.setImageBitmap(BitmapFactory.decodeStream(fis));
                    long start = System.currentTimeMillis();
                    int[][] res2Arr = tfLiteClassificationUtil.predictImage(image_path);
                    long end = System.currentTimeMillis();
                    String show_text = "预测结果标签:" + (int) res2Arr[res2Arr.length-1][0] +
                            "\n名称:" +  classNames.get((int) res2Arr[res2Arr.length-1][0]) +"概率:" + (float) res2Arr[res2Arr.length - 1][1] / 256 +
                            "\n名称:" +  classNames.get((int) res2Arr[res2Arr.length-2][0]) +"概率:" + (float) res2Arr[res2Arr.length - 2][1] / 256 +
                            "\n名称:" +  classNames.get((int) res2Arr[res2Arr.length-3][0]) +"概率:" + (float) res2Arr[res2Arr.length - 3][1] / 256 +
                            "\n时间:" + (end - start) + "ms";
                    textView.setText(show_text);
                } catch (Exception e) {
                    e.printStackTrace();
                }

res2Arr[res2Arr.length - 1][1] / 256,两个整数相除显示为0,添加(float)显示字符串

TFLiteClassificationUtil类功能模块

public TFLiteClassificationUtil(String modelPath) throws Exception {

        File file = new File(modelPath);
        if (!file.exists()) {
            throw new Exception("model file is not exists!");
        }

        try {
            Interpreter.Options options = new Interpreter.Options();

            options.setNumThreads(NUM_THREADS);// 使用多线程预测
            NnApiDelegate delegate = new NnApiDelegate();// 使用Android自带的API或者GPU加速
//            GpuDelegate delegate = new GpuDelegate();
            options.addDelegate(delegate);
            tflite = new Interpreter(file, options);
            // 获取输入,shape为{1, height, width, 3}
            int[] imageShape = tflite.getInputTensor(tflite.getInputIndex("input_1")).shape();
            DataType imageDataType = tflite.getInputTensor(tflite.getInputIndex("input_1")).dataType();
            inputImageBuffer = new TensorImage(imageDataType);
            // 获取输入,shape为{1, NUM_CLASSES}
            int[] probabilityShape = tflite.getOutputTensor(tflite.getOutputIndex("Identity")).shape();
            DataType probabilityDataType = tflite.getOutputTensor(tflite.getOutputIndex("Identity")).dataType();
            //outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);
            outputProbabilityBuffer = TensorBuffer.createFixedSize(tflite.getOutputTensor(0).shape(), DataType.UINT8);

            // 添加图像预处理方式
            imageProcessor = new ImageProcessor.Builder()
                    .add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
                    .add(new NormalizeOp(new float[] {0.0f}, new float[] {255.0f}))
                    .add(new QuantizeOp(0f, 0.003921569f))
                    .add(new CastOp(DataType.UINT8))
                    .build();

            TensorProcessor probabilityPostProcessor = new TensorProcessor.Builder()
                    .add(new DequantizeOp((float) 0, (float) 0.00390625))
                    .add(new NormalizeOp(new float[]{0.0f}, new float[]{1.0f}))
                    .build();
        } catch (Exception e) {
            e.printStackTrace();
            throw new Exception("load model fail!");
        }
    }

    public int[][] predictImage(String image_path) throws Exception {
        if (!new File(image_path).exists()) {
            throw new Exception("image file is not exists!");
        }
        FileInputStream fis = new FileInputStream(image_path);
        Bitmap bitmap = BitmapFactory.decodeStream(fis);
        int[][] result = predictImage(bitmap);
        if (bitmap.isRecycled()) {
            bitmap.recycle();
        }
        return result;
    }

    // 重载方法,直接使用Bitmap预测
    public int[][] predictImage(Bitmap bitmap) throws Exception {
        return predict(bitmap);
    }

    private int[][] predict(Bitmap bmp) throws Exception {
        inputImageBuffer = loadImage(bmp);

        try {
            tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());
        } catch (Exception e) {
            throw new Exception("predict image fail! log:" + e);
        }
        int[] results = outputProbabilityBuffer.getIntArray();
        Log.d("results", Arrays.toString(results));
        int[][] arr = new int[results.length][2];
        for (int i=0;i<results.length;i++) {
            arr[i][0] = i;
            arr[i][1] = results[i];
        }
        Arrays.sort(arr, Comparator.comparingInt(e -> e[1]));
        //int l = getMaxResult(results);
        return arr;//new float[]{l, results[l]};
    }

tflite默认保存格式为UINT8,如果不加add(new CastOp(DataType.UINT8))可能显示

Cannot copy to a TensorFlowLite tensor (input_1) with 150528 bytes from a Java Buffer with 602112 bytes

默认的预训练模型是 EfficientNet-Lite0,如果为其他模型,其输入参数等也要修改。可通过下述方法查看。

Android Studio ->File ->open ->other ->tflite,打开tflite模型,build ->Make Project 会自动生成模型接口类,并移动模型到ml目录,查看类中模型参数。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com