Attempt to invoke virtual method 'void org.tensorflow.lite.Interpreter.run(java.lang.Object, java.lang.Object)' on a null object reference

Question

I am following the text classification demo provided by TensorFlow and running it in Android Studio. However, when I run the app and tap the predict button, it crashes with the following error:

E/AndroidRuntime: FATAL EXCEPTION: main
Process: com.example.mltest, PID: 6318
java.lang.NullPointerException: Attempt to invoke virtual method 'void org.tensorflow.lite.Interpreter.run(java.lang.Object, java.lang.Object)' on a null object reference
    at com.example.mltest.TextClassificationClient.classify(TextClassificationClient.java:154)
    at com.example.mltest.MainActivity.lambda$classify$3$MainActivity(MainActivity.java:73)
    at com.example.mltest.-$$Lambda$MainActivity$iZpagZiqjnywt769FNidzy-9BHU.run(Unknown Source:4)
    at android.os.Handler.handleCallback(Handler.java:873)
    at android.os.Handler.dispatchMessage(Handler.java:99)
    at android.os.Looper.loop(Looper.java:193)
    at android.app.ActivityThread.main(ActivityThread.java:6669)
    at java.lang.reflect.Method.invoke(Native Method)
    at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:493)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:858)

Here is the TextClassificationClient.java file:

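package com.example.mltest;

import android.annotation.SuppressLint;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.util.Log;

import androidx.annotation.WorkerThread;

import org.tensorflow.lite.Interpreter;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

public class TextClassificationClient {

    private static final String TAG = "TextClassificationDemo";
    private static final String MODEL_PATH = "text_classification.tflite";
    private static final String DIC_PATH = "text_classification_vocab.txt";
    private static final String LABEL_PATH = "text_classification_labels.txt";

    private static final int SENTENCE_LEN = 256;
    // Split on spaces, commas, periods, '!', '?' and newlines.
    private static final String SIMPLE_SPACE_OR_PUNCTUATION = " |\\,|\\.|\\!|\\?|\n";

    private static final String START = "<START>";
    private static final String PAD = "<PAD>";
    private static final String UNKNOWN = "<UNKNOWN>";

    private static final int MAX_RESULTS = 3;

    private final Context context;
    private final Map<String, Integer> dic = new HashMap<>();
    private final List<String> labels = new ArrayList<>();
    private Interpreter tflite;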

    public static class Result {

        private final String id;
        private final String title;
        private final Float confidence;

        public Result(String id, String title, Float confidence) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
        }

        public String getId() {
            return id;
        }

        public String getTitle() {
            return title;
        }

        public Float getConfidence() {
            return confidence;
        }

        @SuppressLint("DefaultLocale")
        @Override
        public String toString() {
            String resultString = "";

            if (id != null) {
                resultString += "[" + id + "] ";
            }

            if (title != null) {
                resultString += title + " ";
            }

            if (confidence != null) {
                resultString += String.format("(%.1f%%) ", confidence * 100.0f);
            }

            return resultString.trim();
        }
    }

    public TextClassificationClient(Context context) {
        this.context = context;
    }

    @WorkerThread
    public void load() {
        loadModel();
        loadDictionary();
        loadLabels();
    }

    @WorkerThread
    private synchronized void loadModel() {
        try {
            ByteBuffer buffer = loadModelFile(this.context.getAssets());
            tflite = new Interpreter(buffer);
            Log.v(TAG, "TFLite Model Loaded");
        } catch (IOException ex) {
            // If the model cannot be opened from the assets, the exception is only
            // logged here and tflite stays null.
            Log.v(TAG, ex.getMessage());
        }
    }

    @WorkerThread
    private synchronized void loadDictionary() {
        try {
            loadDictionaryFile(this.context.getAssets());
            Log.v(TAG, "Dictionary Loaded");
        } catch (IOException ex) {
            Log.v(TAG, ex.getMessage());
        }
    }

    @WorkerThread
    private synchronized void loadLabels() {
        try {
            loadLabelFile(this.context.getAssets());
            Log.v(TAG, "Labels Loaded");
        } catch (IOException ex) {
            Log.v(TAG, ex.getMessage());
        }
    }

    @WorkerThread
    private synchronized void unload() {
        tflite.close();
        dic.clear();
        labels.clear();
    }

    @WorkerThread
    public synchronized List<Result> classify(String text) {
        float[][] input = tokenizeInputText(text);

        Log.v(TAG, "Classifying with TFLite");

        float[][] output = new float[1][labels.size()];
        System.out.println("input inside classify in textclient" + Arrays.deepToString(input) + " and labels size is " + labels.size());
        System.out.println("Output is " + Arrays.deepToString(output));
        // The NullPointerException in the stack trace is thrown here when tflite is null.
        tflite.run(input, output);

        PriorityQueue<Result> pq = new PriorityQueue<>(
                MAX_RESULTS, (lhs, rhs) -> Float.compare(rhs.getConfidence(), lhs.getConfidence()));
        for (int i = 0; i < labels.size(); i++) {
            pq.add(new Result("" + i, labels.get(i), output[0][i]));
        }

        final ArrayList<Result> results = new ArrayList<>();
        while (!pq.isEmpty()) {
            results.add(pq.poll());
        }

        return results;
    }

    private static MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException {
        try (AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH);
             FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
            FileChannel fileChannel = inputStream.getChannel();
            long startOffset = fileDescriptor.getStartOffset();
            long declaredLength = fileDescriptor.getDeclaredLength();
            return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
        }
    }

    private void loadLabelFile(AssetManager assetManager) throws IOException {
        try (InputStream ins = assetManager.open(LABEL_PATH);
             BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(ins))) {
            while (bufferedReader.ready()) {
                labels.add(bufferedReader.readLine());
            }
        }
    }

    private void loadDictionaryFile(AssetManager assetManager) throws IOException {
        try (InputStream ins = assetManager.open(DIC_PATH);
             BufferedReader reader = new BufferedReader(new InputStreamReader(ins))) {
            while (reader.ready()) {
                List<String> line = Arrays.asList(reader.readLine().split(" "));
                if (line.size() < 2) {
                    continue;
                }

                dic.put(line.get(0), Integer.parseInt(line.get(1)));
            }
        }
    }

    float[][] tokenizeInputText(String text) {

        float[] tmp = new float[SENTENCE_LEN];
        List<String> array = Arrays.asList(text.split(SIMPLE_SPACE_OR_PUNCTUATION));

        int index = 0;
        // Prepend <START> if it is in the vocabulary file.
        if (dic.containsKey(START)) {
            tmp[index++] = dic.get(START);
        }

        for (String word : array) {
            if (index >= SENTENCE_LEN) {
                break;
            }
            tmp[index++] = dic.containsKey(word) ? dic.get(word) : (int) dic.get(UNKNOWN);
        }
        // Padding and wrapping. Arrays.fill's toIndex is exclusive, so fill up to
        // SENTENCE_LEN; SENTENCE_LEN - 1 would leave the last slot unpadded and throw
        // IllegalArgumentException when the text fills every slot (index == SENTENCE_LEN).
        Arrays.fill(tmp, index, SENTENCE_LEN, (int) dic.get(PAD));
        float[][] ans = {tmp};
        return ans;
    }

    Map<String, Integer> getDic() {
        return this.dic;
    }

    Interpreter getTflite() {
        return this.tflite;
    }

    List<String> getLabels() {
        return this.labels;
    }
}
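
The dictionary parsing and the tokenize-and-pad step in TextClassificationClient do not depend on Android, so they can be tried on a plain JVM. The following is a minimal sketch under stated assumptions, not the demo's code: the class TokenizerSketch, its method names and the toy vocabulary in main are hypothetical. It reads lines until readLine() returns null instead of looping on BufferedReader.ready(), which only reports whether the next read is guaranteed not to block, not whether the end of the stream has been reached. It also pads up to the full sentence length, because the toIndex argument of Arrays.fill is exclusive and must not be smaller than fromIndex.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical, Android-free sketch of the client's preprocessing:
// parse "word index" vocabulary lines, then tokenize and pad a sentence.
class TokenizerSketch {
    static final String START = "<START>";
    static final String PAD = "<PAD>";
    static final String UNKNOWN = "<UNKNOWN>";
    static final String SIMPLE_SPACE_OR_PUNCTUATION = " |\\,|\\.|\\!|\\?|\n";

    // Reads until readLine() returns null (end of stream) and skips malformed lines.
    static Map<String, Integer> parseVocabulary(Reader source) throws IOException {
        Map<String, Integer> dic = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(source)) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] parts = line.split(" ");
                if (parts.length < 2) {
                    continue;
                }
                dic.put(parts[0], Integer.parseInt(parts[1]));
            }
        }
        return dic;
    }

    // Maps words to ids, prepends <START> if present, and pads the rest with <PAD>.
    static float[][] tokenize(String text, Map<String, Integer> dic, int sentenceLen) {
        float[] tmp = new float[sentenceLen];
        int index = 0;
        if (dic.containsKey(START)) {
            tmp[index++] = dic.get(START);
        }
        for (String word : text.split(SIMPLE_SPACE_OR_PUNCTUATION)) {
            if (index >= sentenceLen) {
                break;
            }
            tmp[index++] = dic.getOrDefault(word, dic.get(UNKNOWN));
        }
        // toIndex is exclusive: filling to sentenceLen pads every remaining slot,
        // and index == sentenceLen is allowed (an empty range fills nothing).
        Arrays.fill(tmp, index, sentenceLen, dic.get(PAD).floatValue());
        return new float[][] {tmp};
    }

    public static void main(String[] args) throws IOException {
        String vocab = "<PAD> 0\n<START> 1\n<UNKNOWN> 2\ngood 3\nmovie 4\n";
        Map<String, Integer> dic = parseVocabulary(new StringReader(vocab));
        System.out.println(Arrays.toString(tokenize("good movie!", dic, 6)[0]));
        System.out.println(Arrays.toString(tokenize("good good good good good good good", dic, 6)[0]));
    }
}
```

With this toy vocabulary, "good movie!" becomes [1.0, 3.0, 4.0, 0.0, 0.0, 0.0]: the <START> id, two word ids, then <PAD> ids. The seven-word input is truncated to six slots, and the padding call then fills an empty range instead of throwing.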

And here is the MainActivity.java file:

package com.example.mltest;

import android.os.Bundle;
import android.os.Handler;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.ScrollView;
import android.widget.TextView;

import androidx.appcompat.app.AppCompatActivity;

import java.util.List;

public class MainActivity extends AppCompatActivity {

    private static final String TAG = "TextClassificationDemo";
    private TextClassificationClient client;

    private TextView resultTextView;
    private EditText inputEditText;
    private Handler handler;
    private ScrollView scrollView;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        Log.v(TAG, "On Create");

        client = new TextClassificationClient(getApplicationContext());
        handler = new Handler();
        Button classifyButton = findViewById(R.id.button);

        classifyButton.setOnClickListener((View v) -> {
            classify(inputEditText.getText().toString());
        });

        resultTextView = findViewById(R.id.result_text_view);
        inputEditText = findViewById(R.id.input_text);
        scrollView = findViewById(R.id.scroll_view);
    }

    @Override
    protected void onStart() {
        super.onStart();
        Log.v(TAG, "OnStart");
        handler.post(
                () -> {
                    client.load();
                }
        );
    }

    @Override
    protected void onStop() {
        super.onStop();
        Log.v(TAG, "OnStop");
        // Note: this calls load() again rather than releasing anything. loadLabelFile()
        // appends to labels without clearing it, so repeated calls add duplicate labels,
        // and loadModel() replaces tflite without closing the previous Interpreter.
        handler.post(
                () -> {
                    client.load();
                }
        );
    }

    private void classify(final String text) {

        System.out.println("Text inside classify of Main Activity " + text);
        handler.post(
                () -> {
                    List<TextClassificationClient.Result> results = client.classify(text);

                    showResult(text, results);
                }
        );
    }

    private void showResult(final String inputText, final List<TextClassificationClient.Result> results) {
        runOnUiThread(
                () -> {
                    String textToShow = "Input : " + inputText + "\nOutput : \n";
                    for (int i = 0; i < results.size(); i++) {
                        TextClassificationClient.Result result = results.get(i);
                        textToShow += String.format("    %s: %s\n", result.getTitle(), result.getConfidence());
                    }

                    textToShow += "---------\n";

                    resultTextView.append(textToShow);
                    inputEditText.getText().clear();

                    scrollView.post(() -> scrollView.fullScroll(View.FOCUS_DOWN));
                }
        );
    }
}
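
In MainActivity, onStop() posts client.load() again instead of releasing resources, while loadLabelFile() appends to labels without clearing the list first. Each extra load() therefore grows labels, so classify() would allocate an output array whose second dimension no longer matches the model's output (which is likely to make tflite.run() fail), and the previously created Interpreter is replaced without being closed. A minimal, Android-free sketch of a load/unload pattern is shown below; the ModelResources class, its methods, and the use of a plain Closeable as a stand-in for org.tensorflow.lite.Interpreter are all hypothetical. load() is made idempotent by releasing earlier state first, and unload() is safe to call even when loading never succeeded.

```java
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: resources that may be loaded more than once (onStart)
// and released in onStop. "model" stands in for the TFLite Interpreter.
class ModelResources {
    private final List<String> labels = new ArrayList<>();
    private Closeable model;

    // Idempotent: releases anything loaded earlier before reading again.
    synchronized void load(String labelFileContents, Closeable newModel) throws IOException {
        unload();
        try (BufferedReader reader = new BufferedReader(new StringReader(labelFileContents))) {
            String line;
            while ((line = reader.readLine()) != null) {
                labels.add(line);
            }
        }
        model = newModel;
    }

    // Safe to call even if load() never succeeded (model may be null).
    synchronized void unload() throws IOException {
        if (model != null) {
            model.close();
            model = null;
        }
        labels.clear();
    }

    synchronized int labelCount() {
        return labels.size();
    }

    public static void main(String[] args) throws IOException {
        ModelResources resources = new ModelResources();
        int[] closed = {0};
        Closeable fakeModel = () -> closed[0]++;
        resources.load("Negative\nPositive", fakeModel);
        resources.load("Negative\nPositive", fakeModel); // e.g. a second onStart()
        System.out.println(resources.labelCount() + " labels, closed " + closed[0] + " time(s)");
        resources.unload();
        System.out.println(resources.labelCount() + " labels, closed " + closed[0] + " time(s)");
    }
}
```

In the Android client, the corresponding shape would be an unload() that MainActivity can call from onStop() and that tolerates a null tflite; the label strings above are only sample data.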

Here is my build.gradle file:

apply plugin: 'com.android.application'

android {
    compileSdkVersion 28
    buildToolsVersion "30.0.2"

    defaultConfig {
        applicationId "com.example.mltest"
        minSdkVersion 28
        targetSdkVersion 28
        versionCode 1
        versionName "1.0"

        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
    }

    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }
    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }

    aaptOptions {
        noCompress "tflite"
        noCompress "lite"
    }
}

dependencies {
    implementation fileTree(dir: "libs", include: ["*.jar"])
    implementation 'androidx.appcompat:appcompat:1.2.0'
    implementation 'androidx.constraintlayout:constraintlayout:2.0.1'
    implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
    implementation 'org.tensorflow:tensorflow-lite-task-text:0.0.0-nightly'
    implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
    testImplementation 'junit:junit:4.12'
    androidTestImplementation 'androidx.test.ext:junit:1.1.2'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
}

I have gone through other Stack Overflow questions about the same error, but none of them helped.
Please help me fix this problem. Thank you in advance!

Answer 1

Score: 3

Resolved! The .tflite model file had not been added to the assets folder properly. After adding it, the app ran smoothly.
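
The NullPointerException is only a symptom: loadModel() catches the IOException raised when text_classification.tflite cannot be opened from the assets, logs it at verbose level, and leaves tflite null, so the failure only surfaces later in classify(). One way to make the real cause visible is to fail fast. The sketch below is plain Java rather than Android code, and ModelHolder, its Opener interface and requireModel() are hypothetical stand-ins for loadModel(), AssetManager.openFd() and the tflite field; src/main/assets is the default assets directory of an Android Gradle module.

```java
import java.io.FileNotFoundException;
import java.io.IOException;

// Hypothetical sketch of fail-fast loading: instead of logging and leaving the
// model null, keep the original exception and report it when the model is used.
class ModelHolder {
    // Stand-in for opening the model asset (e.g. via AssetManager.openFd()).
    interface Opener {
        Object open(String path) throws IOException;
    }

    private Object model;
    private IOException loadError;

    synchronized void load(Opener opener, String path) {
        try {
            model = opener.open(path);
            loadError = null;
        } catch (IOException ex) {
            model = null;
            loadError = ex; // keep the cause instead of only logging it
        }
    }

    // Throws a descriptive exception (with the original cause) instead of letting
    // a later call on a null model fail with a bare NullPointerException.
    synchronized Object requireModel() {
        if (model == null) {
            throw new IllegalStateException(
                    "Model not loaded; check that it is in src/main/assets", loadError);
        }
        return model;
    }

    public static void main(String[] args) {
        ModelHolder holder = new ModelHolder();
        holder.load(path -> { throw new FileNotFoundException(path); }, "text_classification.tflite");
        try {
            holder.requireModel();
        } catch (IllegalStateException ex) {
            System.out.println(ex.getMessage() + " (cause: " + ex.getCause() + ")");
        }
    }
}
```

Applied to the client, the equivalent change would be to rethrow from loadModel() or to check tflite before calling tflite.run(); either way the error then points at the missing asset rather than at the interpreter call.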

huangapple
  • Posted on September 12, 2020 at 05:13:07
  • Please keep a link to this article when reposting: https://go.coder-hub.com/63854367.html