Module tfhelper.tflite.converter
Expand source code
import tensorflow as tf
def keras_model_to_tflite(model, config):
    """Convert a Keras model to a TFLite flatbuffer and write it to disk.

    Args:
        model (tf.keras.models.Model): TensorFlow model to convert.
        config (dict): configuration info, e.g.::

            {
                "quantization": false,
                "quantization_type": "int8",  # ["int8", "float16", "float32"]
                "tf_ops": false,
                "exp_converter": false,
                "out_path": "/writing/tflite/model/path"
            }

    Returns:
        bytes: The converted tflite model data in serialized format
        (also written to ``config['out_path']``).
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    config = parse_config(config)

    optimizations = []
    supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
    supported_types = []
    if config['quantization']:
        optimizations.append(tf.lite.Optimize.DEFAULT)
        supported_types.append(config['quantization_type'])
        # BUGFIX: parse_config maps "int8" to tf.int8, never tf.uint8, so the
        # original `== tf.uint8` check could never match and the int8 ops set
        # was never selected. Compare against tf.int8 instead.
        if config['quantization_type'] == tf.int8:
            # BUGFIX: OpsSet members are not iterable; the original
            # `supported_ops += tf.lite.OpsSet.TFLITE_BUILTINS_INT8` raised
            # TypeError. Append the member instead.
            supported_ops.append(tf.lite.OpsSet.TFLITE_BUILTINS_INT8)
            converter.inference_input_type = tf.uint8
            converter.inference_output_type = tf.uint8
            # NOTE(review): full-integer quantization normally also requires a
            # representative_dataset on the converter -- confirm callers set one.
    if config['tf_ops']:
        supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)

    converter.optimizations = optimizations
    converter.target_spec.supported_ops = supported_ops
    converter.target_spec.supported_types = supported_types
    converter.experimental_new_converter = config['exp_converter']

    tflite_model = converter.convert()
    with open(config['out_path'], 'wb') as f:
        f.write(tflite_model)
    print('Saved TFLite model to:', config['out_path'])
    return tflite_model
def parse_config(config):
    """Normalize the tflite config dict used by keras_model_to_tflite.

    Replaces the string 'quantization_type' entry with the matching
    TensorFlow dtype:
        'float32' -> tf.float32
        'float16' -> tf.float16
        'int8'    -> tf.int8
    Any other value falls back to tf.float32.

    Args:
        config (dict): tflite config dict from keras_model_to_tflite.

    Returns:
        dict: The same dict with 'quantization_type' converted in place.
    """
    type_name = config['quantization_type']
    if type_name == "float16":
        dtype = tf.float16
    elif type_name == "int8":
        dtype = tf.int8
    else:
        dtype = tf.float32
    config['quantization_type'] = dtype
    return config
Functions
def keras_model_to_tflite(model, config)
-
Convert Keras model to tflite model
Args
model
:tf.keras.models.Model
- TensorFlow Model
config
:dict
- configuration info Ex) { "quantization": false, "quantization_type": "int8", # ["int8", "float16", "float32"] "tf_ops": false, "exp_converter": false, "out_path": "/writing/tflite/model/path" }
Returns
The converted tflite model data in serialized format.
Expand source code
def keras_model_to_tflite(model, config): """ Convert Keras model to tflite model Args: model (tf.keras.models.Model): TensorFlow Model config (dict): configuration info Ex) { "quantization": false, "quantization_type": "int8", # ["int8", "float16", "float32"] "tf_ops": false, "exp_converter": false, "out_path": "/writing/tflite/model/path" } Returns: The converted tflite model data in serialized format. """ converter = tf.lite.TFLiteConverter.from_keras_model(model) config = parse_config(config) optimizations = [] supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] supported_types = [] if config['quantization']: optimizations += [tf.lite.Optimize.DEFAULT] supported_types += [config['quantization_type']] if config['quantization_type'] == tf.uint8: supported_ops += tf.lite.OpsSet.TFLITE_BUILTINS_INT8 converter.inference_input_type = tf.uint8 converter.inference_output_type = tf.uint8 if config['tf_ops']: supported_ops += [tf.lite.OpsSet.SELECT_TF_OPS] converter.optimizations = optimizations converter.target_spec.supported_ops = supported_ops converter.target_spec.supported_types = supported_types converter.experimental_new_converter = config['exp_converter'] tflite_model = converter.convert() with open(config['out_path'], 'wb') as f: f.write(tflite_model) print('Saved TFLite model to:', config['out_path']) return tflite_model
def parse_config(config)
-
Parse config dict in keras_model_to_tflite Converts data type written as str to tf dtypes 'float32' -> tf.float32 'float16' -> tf.float16 'int8' -> tf.int8
Args
config
:dict
- tflite config dict from keras_model_to_tflite
Returns
dict
- Converted config
Expand source code
def parse_config(config): """ Parse config dict in keras_model_to_tflite Converts data type written as str to tf dtypes 'float32' -> tf.float32 'float16' -> tf.float16 'int8' -> tf.int8 Args: config (dict): tflite config dict from keras_model_to_tflite Returns: dict: Converted config """ qtype = config['quantization_type'] config['quantization_type'] = tf.float16 if qtype == "float16" else tf.int8 if qtype == "int8" else tf.float32 return config