Module tfhelper.tflite.tflite

Expand source code
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot


def predict_tflite_interpreter(interpreter, x_, predict_class=True):
    """
    Predict x_ with tflite interpreter

    Args:
        interpreter (tf.lite.Interpreter): TF Lite Interpreter
        x_ (np.ndarray): test data
        predict_class (bool): True: return argmax(result).
                              False: return as is
    Returns:
        np.array: Predicted Label or Values of top layer.
    """
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    interpreter.set_tensor(input_index, np.expand_dims(x_, axis=0).astype(np.float32))
    interpreter.invoke()
    prediction_ = interpreter.get_tensor(output_index)

    return np.argmax(prediction_), prediction_ if predict_class else prediction_


def evaluate_tflite_interpreter(interpreter, x_test_, y_test_):
    """
    Evaluate tflite interpreter

    Args:
        interpreter (tf.lite.Interpreter): TF Lite Interpreter
        x_test_ (np.ndarray): test data
        y_test_ (np.ndarray): test label

    Returns:
        float: Accuracy
        np.ndarray: Prediction Results
    """
    prediction_result = []
    for x in x_test_:
        predict_label, _ = predict_tflite_interpreter(interpreter, x)
        prediction_result.append(predict_label)

    prediction_result = np.array(prediction_result)
    accuracy = np.sum(prediction_result == y_test_.flatten()) / y_test_.shape[0]

    return accuracy, prediction_result


def load_pruned_model(file_path, strip_model=True):
    """
    Load pruned TensorFlow Keras model
    :param file_path:
    :type file_path: str
    :param strip_model: True if the saved model is stripped.
    :type strip_model: bool
    :return:
    """
    model_ = tf.keras.models.load_model(file_path, custom_objects={
        'PruneLowMagnitude': tfmot.sparsity.keras.pruning_wrapper.PruneLowMagnitude})
    model_ = tfmot.sparsity.keras.strip_pruning(model_) if strip_model else model_

    return model_

Functions

def evaluate_tflite_interpreter(interpreter, x_test_, y_test_)

Evaluate tflite interpreter

Args

interpreter : tf.lite.Interpreter
TF Lite Interpreter
x_test_ : np.ndarray
test data
y_test_ : np.ndarray
test label

Returns

float
Accuracy
np.ndarray
Prediction Results
Expand source code
def evaluate_tflite_interpreter(interpreter, x_test_, y_test_):
    """
    Evaluate tflite interpreter

    Args:
        interpreter (tf.lite.Interpreter): TF Lite Interpreter
        x_test_ (np.ndarray): test data
        y_test_ (np.ndarray): test label

    Returns:
        float: Accuracy
        np.ndarray: Prediction Results
    """
    prediction_result = []
    for x in x_test_:
        predict_label, _ = predict_tflite_interpreter(interpreter, x)
        prediction_result.append(predict_label)

    prediction_result = np.array(prediction_result)
    accuracy = np.sum(prediction_result == y_test_.flatten()) / y_test_.shape[0]

    return accuracy, prediction_result
def load_pruned_model(file_path, strip_model=True)

Load pruned TensorFlow Keras model :param file_path: :type file_path: str :param strip_model: True if the saved model is stripped. :type strip_model: bool :return:

Expand source code
def load_pruned_model(file_path, strip_model=True):
    """
    Load pruned TensorFlow Keras model
    :param file_path:
    :type file_path: str
    :param strip_model: True if the saved model is stripped.
    :type strip_model: bool
    :return:
    """
    model_ = tf.keras.models.load_model(file_path, custom_objects={
        'PruneLowMagnitude': tfmot.sparsity.keras.pruning_wrapper.PruneLowMagnitude})
    model_ = tfmot.sparsity.keras.strip_pruning(model_) if strip_model else model_

    return model_
def predict_tflite_interpreter(interpreter, x_, predict_class=True)

Predict x_ with tflite interpreter

Args

interpreter : tf.lite.Interpreter
TF Lite Interpreter
x_ : np.ndarray
test data
predict_class : bool
True: return argmax(result). False: return as is

Returns

np.array
Predicted Label or Values of top layer.
Expand source code
def predict_tflite_interpreter(interpreter, x_, predict_class=True):
    """
    Predict x_ with tflite interpreter

    Args:
        interpreter (tf.lite.Interpreter): TF Lite Interpreter
        x_ (np.ndarray): test data
        predict_class (bool): True: return argmax(result).
                              False: return as is
    Returns:
        np.array: Predicted Label or Values of top layer.
    """
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    interpreter.set_tensor(input_index, np.expand_dims(x_, axis=0).astype(np.float32))
    interpreter.invoke()
    prediction_ = interpreter.get_tensor(output_index)

    return np.argmax(prediction_), prediction_ if predict_class else prediction_