How can I use a TensorFlow Lite model in TensorFlow code?


# Question

I am working on a personal project on a Raspberry Pi 4 using TensorFlow. I am getting around 1.39 FPS, and I want to convert the model to TensorFlow Lite to get more FPS and to make use of a Coral USB Accelerator. How can I use a TensorFlow Lite model with this code?

```python
# Import packages
import os
import cv2
import numpy as np
import argparse
import sys
import tensorflow.compat.v1 as tf

# This script uses TF1-style graph and session APIs
tf.disable_v2_behavior()

# Set up camera constants
IM_WIDTH = 640
IM_HEIGHT = 480

# Select camera type (if user enters --usbcam when calling this script,
# a USB webcam will be used)

parser = argparse.ArgumentParser()
parser.add_argument('--usbcam', help='Use a USB webcam instead of picamera',
                    action='store_true')
args = parser.parse_args()
if args.usbcam:
    camera_type = 'usb'

#### Initialize TensorFlow model ####

# This is needed since the working directory is the object_detection folder.
sys.path.append('..')

# Import utilities
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

# Name of the directory containing the object detection module you're using
MODEL_NAME = 'ssd_inception_v2_coco_2017_11_17'

# Grab path to the current working directory
CWD_PATH = os.getcwd()

# Path to the frozen detection graph .pb file, which contains the model used for object detection
PATH_TO_CKPT = os.path.join(CWD_PATH, MODEL_NAME, 'frozen_inference_graph.pb')

# Path to label map file
PATH_TO_LABELS = os.path.join(CWD_PATH, 'data', 'mscoco_label_map.pbtxt')

# Number of classes the object detector can identify
NUM_CLASSES = 90

## Load the label map.
# Label maps map indices to category names, so when the convolution
# network predicts '5', we know it corresponds to 'airplane'.
# Here we use internal utility functions, but any dictionary mapping integers to appropriate string labels will work.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Load the TensorFlow model into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

    sess = tf.Session(graph=detection_graph)

# Define input and output tensors for the object detection classifier

# Input tensor is the image
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

# Output tensors are the detection boxes, scores, and classes
detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')

#### Initialize other parameters ####

# Initialize frame rate calculation
frame_rate_calc = 1
freq = cv2.getTickFrequency()
font = cv2.FONT_HERSHEY_SIMPLEX

# Define inside box coordinates (top left and bottom right)
TL_inside = (int(IM_WIDTH * 0.016), int(IM_HEIGHT * 0.021))
BR_inside = (int(IM_WIDTH * 0.323), int(IM_HEIGHT * 0.979))

# Define outside box coordinates (top left and bottom right)
TL_outside = (int(IM_WIDTH * 0.333), int(IM_HEIGHT * 0.021))
BR_outside = (int(IM_WIDTH * 0.673), int(IM_HEIGHT * 0.979))

# Define right box coordinates (top left and bottom right)
TL_right = (int(IM_WIDTH * 0.683), int(IM_HEIGHT * 0.021))
BR_right = (int(IM_WIDTH * 0.986), int(IM_HEIGHT * 0.979))

# Initialize control variables used for pet detection
detected_inside = False
detected_outside = False
detected_right = False

inside_counter = 0
outside_counter = 0
right_counter = 0

pause = 0
pause_counter = 0

#### Pet detection function ####

# This function contains the code to detect a pet, determine if it's
# inside or outside, and send a text to the user's phone.
def pet_detector(frame):

    # Use global variables for the control variables so they retain their value after the function exits
    global detected_inside, detected_outside, detected_right
    global inside_counter, outside_counter, right_counter
    global pause, pause_counter

    frame_expanded = np.expand_dims(frame, axis=0)

    # Perform the actual detection by running the model with the image as input
    (boxes, scores, classes, num) = sess.run(
        [detection_boxes, detection_scores, detection_classes, num_detections],
        feed_dict={image_tensor: frame_expanded})

    # Draw the results of the detection (i.e., visualize the results)
    vis_util.visualize_boxes_and_labels_on_image_array(
        frame,
        np.squeeze(boxes),
        np.squeeze(classes).astype(np.int32),
        np.squeeze(scores),
        category_index,
        use_normalized_coordinates=True,
        line_thickness=8,
        min_score_thresh=0.40)

    # Draw boxes defining "outside" and "inside" locations
    cv2.rectangle(frame, TL_outside, BR_outside, (255, 20, 20), 3)
    cv2.putText(frame, "Outside box", (TL_outside[0] + 10, TL_outside[1] - 10), font, 1, (255, 20, 255), 3, cv2.LINE_AA)
    cv2.rectangle(frame, TL_inside, BR_inside, (20, 20, 255), 3)
    cv2.putText(frame, "Inside box", (TL_inside[0] + 10, TL_inside[1] - 10), font, 1, (20, 255, 255), 3, cv2.LINE_AA)
    cv2.rectangle(frame, TL_right, BR_right, (20, 255, 25), 3)
    cv2.putText(frame, "Right box", (TL_right[0] + 10, TL_right[1] - 10), font, 1, (20, 255, 255), 3, cv2.LINE_AA)

    # Check the class of the top detected object by looking at classes[0][0].
    # If the top detected object is a cat (17) or a dog (18) (or a teddy bear (88)
    # for test purposes), find its center coordinates from the boxes[0][0] variable,
    # which holds the detected object's coordinates as (ymin, xmin, ymax, xmax).
    if (int(classes[0][0]) == 17 or int(classes[0][0]) == 18 or int(classes[0][0]) == 88) and pause == 0:
        x = int(((boxes[0][0][1] + boxes[0][0][3]) / 2) * IM_WIDTH)
        y = int(((boxes[0][0][0] + boxes[0][0][2]) / 2) * IM_HEIGHT)

        # Draw a circle at the center of the object
        cv2.circle(frame, (x, y), 5, (75, 13, 180), -1)

        # If the object is in the inside box, increment the inside counter
        if (x > TL_inside[0]) and (x < BR_inside[0]) and (y > TL_inside[1]) and (y < BR_inside[1]):
            inside_counter = inside_counter + 1

        # If the object is in the outside box, increment the outside counter
        if (x > TL_outside[0]) and (x < BR_outside[0]) and (y > TL_outside[1]) and (y < BR_outside[1]):
            outside_counter = outside_counter + 1

        # If the object is in the right box, increment the right counter
        if (x > TL_right[0]) and (x < BR_right[0]) and (y > TL_right[1]) and (y < BR_right[1]):
            right_counter = right_counter + 1

    # If the pet has been detected inside, set the detected_inside flag,
    # reset the counters, and pause detection by setting the "pause" flag
    if inside_counter == 1:
        detected_inside = True
        inside_counter = 0
        outside_counter = 0
        right_counter = 0
        pause = 1

    # If the pet has been detected outside, set the detected_outside flag,
    # reset the counters, and pause detection by setting the "pause" flag
    if outside_counter == 1:
        detected_outside = True
        inside_counter = 0
        outside_counter = 0
        right_counter = 0
        pause = 1

    # If the pet has been detected in the right box, set the detected_right flag,
    # reset the counters, and pause detection by setting the "pause" flag
    if right_counter == 1:
        detected_right = True
        inside_counter = 0
        outside_counter = 0
        right_counter = 0
        pause = 1

    # If the pause flag is set, draw a message on the screen
    if pause == 1:

        if detected_inside == True:
            cv2.putText(frame, 'Left detected!', (int(IM_WIDTH * 0.027), int(IM_HEIGHT - 60)), font, 3, (0, 0, 0), 7, cv2.LINE_AA)
            cv2.putText(frame, 'Left detected!', (int(IM_WIDTH * 0.967), int(IM_HEIGHT - 60)), font, 3, (95, 176, 23), 5, cv2.LINE_AA)

        if detected_outside == True:
            cv2.putText(frame, 'Mid detected!', (int(IM_WIDTH * 0.027), int(IM_HEIGHT - 60)), font, 3, (0, 0, 0), 7, cv2.LINE_AA)
            cv2.putText(frame, 'Mid detected!', (int(IM_WIDTH * 0.967), int(IM_HEIGHT - 60)), font, 3, (95, 176, 23), 5, cv2.LINE_AA)

        if detected_right == True:
            cv2.putText(frame, 'Right detected!', (int(IM_WIDTH * 0.027), int(IM_HEIGHT - 60)), font, 3, (0, 0, 0), 7, cv2.LINE_AA)
            cv2.putText(frame, 'Right detected!', (int(IM_WIDTH * 0.967), int(IM_HEIGHT - 60)), font, 3, (95, 176, 23), 5, cv2.LINE_AA)

        # Increment the pause counter until it reaches its limit,
        # then unpause the application (set the pause flag to 0)
        pause_counter = pause_counter + 1
        if pause_counter > 3:
            pause = 0
            pause_counter = 0
            detected_inside = False
            detected_outside = False
            detected_right = False

    # Draw counter info
    cv2.putText(frame, 'Detection counter: ' + str(max(inside_counter, outside_counter, right_counter)), (10, 100), font, 0.5, (255, 255, 0), 1, cv2.LINE_AA)
    cv2.putText(frame, 'Pause counter: ' + str(pause_counter), (10, 150), font, 0.5, (255, 255, 0), 1, cv2.LINE_AA)

    return frame

#### Initialize camera and perform object detection ####

# Initialize USB webcam feed
camera = cv2.VideoCapture(0)
ret = camera.set(3, IM_WIDTH)
ret = camera.set(4, IM_HEIGHT)

# Continuously capture frames and perform object detection on them
while True:

    t1 = cv2.getTickCount()

    # Acquire a frame from the camera
    ret, frame = camera.read()

    # Pass the frame into the pet detection function
    frame = pet_detector(frame)

    # Draw FPS
    cv2.putText(frame, "FPS: {0:.2f}".format(frame_rate_calc), (30, 50), font, 1, (255, 255, 0), 2, cv2.LINE_AA)

    # All the results have been drawn on the frame, so it's time to display it
    cv2.imshow('Object detector', frame)

    # FPS calculation
    t2 = cv2.getTickCount()
    time1 = (t2 - t1) / freq
    frame_rate_calc = 1 / time1

    # Press 'q' to quit
    if cv2.waitKey(1) == ord('q'):
        break

camera.release()
cv2.destroyAllWindows()
```

I do not know what the equivalent syntax is between TensorFlow and TensorFlow Lite. Any help would be appreciated! Cheers!

I have tried straight up swapping the .pb model for a .tflite model, but that did not work at all.
# Answer 1
**Score**: 1
To run a TFLite model, you need to use the TFLite interpreter. Here is an example:
```python
import tensorflow as tf

# Load the model
interpreter = tf.lite.Interpreter(model_path='your_model.tflite')
interpreter.allocate_tensors()

# Get the input/output details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Load your data
input_data = <your_data>

# Set the model input
interpreter.set_tensor(input_details[0]['index'], input_data)

# Run the model
interpreter.invoke()

# Get the output
output_data = interpreter.get_tensor(output_details[0]['index'])
```
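
A `.pb` frozen graph cannot simply be renamed to `.tflite`, which is likely why swapping the files did not work; it has to be converted first. For SSD models from the TF1 Object Detection API, one commonly used path is to re-export the graph with the API's `export_tflite_ssd_graph.py` script and then convert the re-exported `tflite_graph.pb`. The sketch below assumes that workflow; the file names are placeholders, and the 300x300 input shape is the usual size for `ssd_inception_v2` (check your pipeline config):

```python
import tensorflow.compat.v1 as tf

# Assumes the model was re-exported with the Object Detection API's
# export_tflite_ssd_graph.py, which produces tflite_graph.pb with the
# standard TFLite_Detection_PostProcess outputs (boxes/classes/scores/count)
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file='tflite_graph.pb',  # placeholder path
    input_arrays=['normalized_input_image_tensor'],
    output_arrays=['TFLite_Detection_PostProcess',
                   'TFLite_Detection_PostProcess:1',
                   'TFLite_Detection_PostProcess:2',
                   'TFLite_Detection_PostProcess:3'],
    input_shapes={'normalized_input_image_tensor': [1, 300, 300, 3]})
converter.allow_custom_ops = True  # the post-process op is a custom TFLite op
tflite_model = converter.convert()

with open('detect.tflite', 'wb') as f:
    f.write(tflite_model)
```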
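Applied to the code in the question, the frozen-graph/session loading block and the `sess.run(...)` call inside `pet_detector` would be replaced with interpreter calls. Below is a minimal sketch under the assumption of a converted SSD model with the standard four post-process outputs; `detect.tflite` and `run_tflite_detection` are illustrative names, not part of the original code:

```python
import cv2
import numpy as np
import tensorflow as tf

# Load the converted model once, outside the per-frame loop
interpreter = tf.lite.Interpreter(model_path='detect.tflite')  # placeholder name
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
in_height = input_details[0]['shape'][1]
in_width = input_details[0]['shape'][2]

def run_tflite_detection(frame):
    # SSD TFLite models take a fixed-size RGB input (e.g. 300x300),
    # unlike the frozen graph, which resized internally
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (in_width, in_height))
    input_data = np.expand_dims(resized, axis=0)

    # Float models expect normalized float input; quantized models take uint8
    if input_details[0]['dtype'] == np.float32:
        input_data = (np.float32(input_data) - 127.5) / 127.5

    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()

    # Standard detection post-process outputs; the order can vary between
    # exports, so verify it against output_details for your model
    boxes = interpreter.get_tensor(output_details[0]['index'])
    classes = interpreter.get_tensor(output_details[1]['index'])
    scores = interpreter.get_tensor(output_details[2]['index'])
    num = interpreter.get_tensor(output_details[3]['index'])
    return boxes, scores, classes, num
```

Inside `pet_detector`, `(boxes, scores, classes, num) = sess.run(...)` then becomes `(boxes, scores, classes, num) = run_tflite_detection(frame)`, and the `tf.Session`/graph-loading code is no longer needed. One caveat: models converted this way typically emit 0-based class indices, so a cat is likely class 16 rather than 17; print a few detections to confirm before relying on the class checks.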

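Since the goal also includes the Coral USB Accelerator: the Edge TPU requires a fully integer-quantized model compiled with the `edgetpu_compiler`, and the Edge TPU delegate must be loaded when the interpreter is created. A minimal sketch, assuming a compiled model named `detect_edgetpu.tflite` (hypothetical) and the lightweight `tflite_runtime` package:

```python
from tflite_runtime.interpreter import Interpreter, load_delegate

# load_delegate points the interpreter at the Edge TPU runtime library
interpreter = Interpreter(
    model_path='detect_edgetpu.tflite',  # placeholder name
    experimental_delegates=[load_delegate('libedgetpu.so.1')])
interpreter.allocate_tensors()
```

The rest of the inference code is unchanged; only the interpreter construction differs.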