Skip to content

Commit

Permalink
New PR for "ultralytics#7736"
Browse files Browse the repository at this point in the history
  • Loading branch information
triple-Mu committed Sep 4, 2022
1 parent 7aa263c commit 99de36c
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 1 deletion.
16 changes: 16 additions & 0 deletions examples/export.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/usr/bin/env bash
# End-to-end demo: export YOLOv5s to ONNX with an embedded NMS stage, build a
# TensorRT FP16 engine from it, then run a sample inference and a benchmark.
# Run from the examples/ directory (the script steps up to the repo root first).
set -e  # stop at the first failing step; every later step needs the earlier artifacts

cd ../
mkdir -p weights

# download official weights
wget https://gh.ddlc.top/https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5s.pt -P weights
# export yolov5s.onnx
# --backend trt is required here: trt_infer.py reads the num_dets/det_boxes/det_scores/det_classes
# outputs, while the default backend (ort) exports a single 'outputs' tensor instead.
python3 export.py --weights weights/yolov5s.pt --include onnx engine --nms --backend trt
mv weights/yolov5s.onnx ./examples/yolov5s_nms.onnx
cd examples
trtexec --onnx=./yolov5s_nms.onnx --saveEngine=./yolov5s_nms_fp16.engine --fp16

# result test
wget https://oneflow-static.oss-cn-beijing.aliyuncs.com/tripleMu/image1.jpg
python3 trt_infer.py
trtexec --loadEngine=./yolov5s_nms_fp16.engine --verbose --useCudaGraph --noDataTransfers --shapes=images:1x3x640x640
99 changes: 99 additions & 0 deletions examples/trt_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import sys

import cv2

sys.path.append('../')
import random
import time
from collections import OrderedDict, namedtuple

import numpy as np
import tensorrt as trt
import torch
from PIL import Image

from utils.augmentations import letterbox

# Class names; the list position is used as the class id when labelling detections below.
names = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
    'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
    'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet',
    'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
# One random colour per class name, used for the box outlines.
colors = {name: [random.randint(0, 255) for _ in range(3)] for name in names}

w = './yolov5s_nms_fp16.engine'
image_path = './image1.jpg'
device = torch.device('cuda:0')

# Infer TensorRT Engine
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(logger, namespace="")  # registers TensorRT plugins before deserializing the engine
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
    model = runtime.deserialize_cuda_engine(f.read())

# Allocate one device tensor per engine binding and remember its raw pointer.
bindings = OrderedDict()
fp16 = False  # default updated below
for index in range(model.num_bindings):
    name = model.get_binding_name(index)
    dtype = trt.nptype(model.get_binding_dtype(index))
    shape = tuple(model.get_binding_shape(index))
    data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
    bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
    if model.binding_is_input(index) and dtype == np.float16:
        fp16 = True
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
context = model.create_execution_context()

# Pre-process: letterbox, BGR -> RGB, HWC -> NCHW, scale to [0, 1].
image = cv2.imread(image_path)
# ratio/dwdh are not needed afterwards: boxes are drawn on the letterboxed copy, not the original image.
image, ratio, dwdh = letterbox(image, auto=False)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

image_copy = image.copy()

image = image.transpose((2, 0, 1))
image = np.expand_dims(image, 0)
image = np.ascontiguousarray(image)
im = torch.from_numpy(image).to(device)
# Feed the precision the engine's input binding was built with; a float32 buffer handed to an
# fp16 input would be misread.
im = im.half() if fp16 else im.float()
im /= 255

# warmup for 10 times
for _ in range(10):
    tmp = torch.randn_like(im)  # same shape/dtype/device as the real input
    binding_addrs['images'] = int(tmp.data_ptr())
    context.execute_v2(list(binding_addrs.values()))

torch.cuda.synchronize()  # make sure pending pre-processing kernels are not counted in the timing
start = time.perf_counter()
binding_addrs['images'] = int(im.data_ptr())
context.execute_v2(list(binding_addrs.values()))
print(f'Cost {time.perf_counter()-start} s')

nums = bindings['num_dets'].data
boxes = bindings['det_boxes'].data
scores = bindings['det_scores'].data
classes = bindings['det_classes'].data

print(nums)
print(boxes)
print(scores)
print(classes)

# Only the first `num` rows of each output are read for image 0.
num = int(nums[0][0])
box_img = boxes[0, :num].round().int()
score_img = scores[0, :num]
clss_img = classes[0, :num]
for box, score, clss in zip(box_img, score_img, clss_img):
    name = names[int(clss)]
    color = colors[name]
    cv2.rectangle(image_copy, box[:2].tolist(), box[2:].tolist(), color, 2)
    cv2.putText(image_copy,
                name, (int(box[0]), int(box[1]) - 2),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.75, [225, 255, 255],
                thickness=2)

Image.fromarray(image_copy).show()
68 changes: 67 additions & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,66 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
return f, model_onnx


@try_export
def export_onnx_for_backend(model, im, file, opset, nms_cfg, dynamic, simplify, prefix=colorstr('ONNX:')):
    """Export the model to ONNX with an NMS stage appended (end-to-end graph).

    Args:
        model: detection model; it is wrapped in ``models.common.End2End`` before export.
        im: example input tensor used for tracing; its device is passed to the wrapper.
        file: path whose suffix is replaced by ``.onnx`` to form the output path.
        opset: ONNX opset version.
        nms_cfg: sequence unpacked into ``End2End`` after the model; its last item is the
            backend name, ``'ort'`` or ``'trt'``.
        dynamic: if true, export on CPU and mark the batch axis as dynamic.
        simplify: if true, run onnx-simplifier on the result; a failure is logged, not raised.
        prefix: log message prefix.

    Returns:
        Tuple of (output path, ONNX model object).

    Raises:
        NotImplementedError: if the backend is neither ``'ort'`` nor ``'trt'``
            (the ``try_export`` decorator, defined outside this view, may intercept it).
    """
    check_requirements(('onnx',))
    import onnx

    LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
    f = file.with_suffix('.onnx')

    # Validate the backend up front so an unsupported value fails before the model is touched.
    backend = nms_cfg[-1]
    if backend == 'trt':
        output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
        output_axes = {n: {0: 'batch'} for n in output_names}
    elif backend == 'ort':
        output_names = ['outputs']
        # The ORT head returns one row per selected box, so axis 0 is a detection count, not the batch.
        output_axes = {'outputs': {0: 'num_dets'}}
    else:
        raise NotImplementedError(f"unsupported NMS backend '{backend}', expected 'ort' or 'trt'")

    from models.common import End2End
    model = End2End(model, *nms_cfg, device=im.device)

    # The input batch axis must be declared dynamic too, otherwise only the outputs would vary.
    dynamic_cfg = {'images': {0: 'batch'}, **output_axes} if dynamic else None

    torch.onnx.export(
        model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
        im.cpu() if dynamic else im,
        f,
        verbose=False,
        opset_version=opset,
        training=torch.onnx.TrainingMode.EVAL,
        do_constant_folding=True,
        input_names=['images'],
        output_names=output_names,
        dynamic_axes=dynamic_cfg)

    # Checks
    model_onnx = onnx.load(f)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    # Metadata
    d = {'stride': int(max(model.stride)), 'names': model.names}
    for k, v in d.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)
    onnx.save(model_onnx, f)

    # Simplify
    if simplify:
        try:
            cuda = torch.cuda.is_available()
            check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
            import onnxsim

            LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
            model_onnx, check = onnxsim.simplify(model_onnx)
            assert check, 'assert check failed'
            onnx.save(model_onnx, f)
        except Exception as e:
            # Simplification is best-effort: the unsimplified model saved above stays valid.
            LOGGER.info(f'{prefix} simplifier failure: {e}')
    return f, model_onnx


@try_export
def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
# YOLOv5 OpenVINO export
Expand Down Expand Up @@ -453,6 +513,7 @@ def run(
verbose=False, # TensorRT: verbose log
workspace=4, # TensorRT: workspace size (GB)
nms=False, # TF: add NMS to model
backend='ort', # Backend for export NMS
agnostic_nms=False, # TF: add agnostic NMS to model
topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep
Expand All @@ -465,6 +526,7 @@ def run(
flags = [x in include for x in fmts]
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans
end2end, onnx = onnx and nms, onnx and not nms
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights

# Load PyTorch model
Expand Down Expand Up @@ -500,7 +562,7 @@ def run(
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")

# Exports
f = [''] * 10 # exported filenames
f = [''] * 11 # exported filenames
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
if jit:
f[0], _ = export_torchscript(model, im, file, optimize)
Expand Down Expand Up @@ -539,6 +601,9 @@ def run(
if tfjs:
f[9], _ = export_tfjs(file)

if end2end:
nms_cfg = [topk_all, iou_thres, conf_thres, backend]
f[10], _ = export_onnx_for_backend(model, im, file, opset, nms_cfg, dynamic, simplify)
# Finish
f = [str(x) for x in f if x] # filter out '' and None
if any(f):
Expand Down Expand Up @@ -571,6 +636,7 @@ def parse_opt():
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
parser.add_argument('--backend', type=str, default='ort', help='Backend for export NMS')
parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
Expand Down
156 changes: 156 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import math
import platform
import random
import warnings
from collections import OrderedDict, namedtuple
from copy import copy
Expand Down Expand Up @@ -777,3 +778,158 @@ def forward(self, x):
if isinstance(x, list):
x = torch.cat(x, 1)
return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))


class ORT_NMS(torch.autograd.Function):
    """Placeholder autograd function that exports as the ONNX ``NonMaxSuppression`` op.

    ``forward`` performs no suppression; it fabricates a randomly sized index tensor with
    the expected layout, while ``symbolic`` emits the actual ONNX node.
    """

    @staticmethod
    def forward(ctx,
                boxes,
                scores,
                max_output_boxes_per_class=torch.tensor([100]),
                iou_threshold=torch.tensor([0.45]),
                score_threshold=torch.tensor([0.25])):
        """Return a dummy ``(num_det, 3)`` int64 tensor of ``[batch index, 0, box index]`` rows."""
        target = boxes.device
        n_batch = scores.shape[0]
        count = random.randint(0, 100)  # arbitrary number of fake detections
        # Sorted random batch ids, drawn on CPU and then moved to the boxes' device.
        batch_col = torch.sort(torch.randint(0, n_batch, (count,))).values.to(target)
        box_col = torch.arange(100, 100 + count, device=target)
        middle_col = torch.zeros((count,), dtype=torch.int64, device=target)
        rows = torch.stack((batch_col, middle_col, box_col), dim=1).contiguous()
        return rows.to(torch.int64)

    @staticmethod
    def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
        """Emit the ``NonMaxSuppression`` node with the five inputs in their given order."""
        nms_inputs = (boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
        return g.op("NonMaxSuppression", *nms_inputs)


class TRT_NMS(torch.autograd.Function):
    """Placeholder autograd function that exports as the ``TRT::EfficientNMS_TRT`` op.

    ``forward`` performs no suppression; it returns random tensors with the four output
    shapes, while ``symbolic`` emits the plugin node with its attributes.
    """

    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=1,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version="1",
        score_activation=0,
        score_threshold=0.25,
    ):
        """Return dummy ``(num_det, det_boxes, det_scores, det_classes)`` tensors.

        ``scores`` must be 3-D; its first and last sizes set the batch size and class count.
        """
        n_batch, _, n_classes = scores.shape
        per_image = (n_batch, max_output_boxes)
        # The draw order below is kept fixed so the random stream is consumed the same way.
        counts = torch.randint(0, max_output_boxes, (n_batch, 1), dtype=torch.int32)
        fake_boxes = torch.randn(*per_image, 4)
        fake_scores = torch.randn(per_image)
        fake_classes = torch.randint(0, n_classes, per_image, dtype=torch.int32)
        return counts, fake_boxes, fake_scores, fake_classes

    @staticmethod
    def symbolic(g,
                 boxes,
                 scores,
                 background_class=-1,
                 box_coding=1,
                 iou_threshold=0.45,
                 max_output_boxes=100,
                 plugin_version="1",
                 score_activation=0,
                 score_threshold=0.25):
        """Emit the ``TRT::EfficientNMS_TRT`` node and return its four outputs."""
        plugin_attrs = dict(background_class_i=background_class,
                            box_coding_i=box_coding,
                            iou_threshold_f=iou_threshold,
                            max_output_boxes_i=max_output_boxes,
                            plugin_version_s=plugin_version,
                            score_activation_i=score_activation,
                            score_threshold_f=score_threshold)
        num_out, boxes_out, scores_out, classes_out = g.op("TRT::EfficientNMS_TRT",
                                                           boxes,
                                                           scores,
                                                           outputs=4,
                                                           **plugin_attrs)
        return num_out, boxes_out, scores_out, classes_out


class ONNX_ORT(nn.Module):
    """NMS head for the ONNX Runtime backend, built on the ``NonMaxSuppression`` op.

    Args:
        max_obj: maximum boxes kept per class, forwarded to ``ORT_NMS``.
        iou_thres: IoU threshold forwarded to ``ORT_NMS``.
        score_thres: score threshold forwarded to ``ORT_NMS``.
        device: device for the constant tensors; CPU when falsy.
    """

    def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, device=None):
        super().__init__()
        self.device = device if device else torch.device("cpu")
        # Constants are registered as non-persistent buffers so that Module.to()/.cpu() moves
        # them together with the wrapped model (plain tensor attributes would stay behind on
        # the original device), while state_dict() stays unchanged.
        self.register_buffer('max_obj', torch.tensor([max_obj], device=self.device), persistent=False)
        self.register_buffer('iou_threshold', torch.tensor([iou_thres], device=self.device), persistent=False)
        self.register_buffer('score_threshold', torch.tensor([score_thres], device=self.device), persistent=False)
        self.max_wh = 7680  # per-class coordinate offset used in forward()
        # Right-multiplying (a, b, c, d) by this matrix gives (a - c/2, b - d/2, a + c/2, b + d/2),
        # i.e. centre/size boxes become corner boxes.
        self.register_buffer('convert_matrix',
                             torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
                                          dtype=torch.float32,
                                          device=self.device),
                             persistent=False)

    def forward(self, x):
        """Select detections from raw predictions.

        Args:
            x: tensor indexed as ``[batch, box, channel]``; channels 0-3 are the box, channel 4
                the confidence and channels 5+ the per-class scores.

        Returns:
            2-D float tensor with one row per selected box:
            ``[batch index, 4 box coords, class index, score]``.
        """
        # Out-of-place ops: slices of x are views, so in-place updates would modify the caller's tensor.
        box = x[:, :, :4] @ self.convert_matrix
        score = x[:, :, 5:] * x[:, :, 4:5]
        objScore, objCls = score.max(2, keepdim=True)
        # Shift every box by class_index * max_wh before NMS so that boxes of different classes
        # are moved apart (assumes coordinates stay below max_wh - TODO confirm).
        dis = objCls.float() * self.max_wh
        nmsbox = box + dis
        objScore1 = objScore.transpose(1, 2).contiguous()
        selected_indices = ORT_NMS.apply(nmsbox, objScore1, self.max_obj, self.iou_threshold, self.score_threshold)
        # Column 0 holds the batch index and column 2 the box index of each selected entry.
        X, Y = selected_indices[:, 0], selected_indices[:, 2]
        resBoxes = box[X, Y, :]
        resClasses = objCls[X, Y, :].float()
        resScores = objScore[X, Y, :]
        X = X.unsqueeze(1).float()
        return torch.cat([X, resBoxes, resClasses, resScores], 1)


class ONNX_TRT(nn.Module):
    """NMS head for the TensorRT backend, built on the ``EfficientNMS_TRT`` plugin op.

    Args:
        max_obj: maximum number of output boxes, forwarded to ``TRT_NMS``.
        iou_thres: IoU threshold forwarded to ``TRT_NMS``.
        score_thres: score threshold forwarded to ``TRT_NMS``.
        device: stored on the instance; CPU when falsy.
    """

    def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, device=None):
        super().__init__()
        self.device = device if device else torch.device('cpu')
        # Plain ints: a trailing comma here would turn these into one-element tuples and
        # export them as list-valued attributes instead of scalar ones.
        self.background_class = -1
        self.box_coding = 1  # NOTE(review): plugin-specific box encoding id - confirm meaning in the plugin docs
        self.iou_threshold = iou_thres
        self.max_obj = max_obj
        self.plugin_version = '1'
        self.score_activation = 0
        self.score_threshold = score_thres

    def forward(self, x):
        """Run the NMS op on raw predictions.

        Args:
            x: tensor indexed as ``[batch, box, channel]``; channels 0-3 are the box, channel 4
                the confidence and channels 5+ the per-class scores.

        Returns:
            Tuple ``(num_det, det_boxes, det_scores, det_classes)`` as produced by ``TRT_NMS``.
        """
        box = x[:, :, :4]
        # Out-of-place product: the slices are views of x, so `*=` would modify the caller's tensor.
        score = x[:, :, 5:] * x[:, :, 4:5]
        num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(box, score, self.background_class, self.box_coding,
                                                                    self.iou_threshold, self.max_obj,
                                                                    self.plugin_version, self.score_activation,
                                                                    self.score_threshold)
        return num_det, det_boxes, det_scores, det_classes


class End2End(nn.Module):
    """Chain a detection model with an NMS head so both export as a single graph.

    Args:
        model: detection model; its last layer's ``dynamic`` flag is set to False and the
            model is moved to ``device``.
        max_obj, iou_thres, score_thres: forwarded to the NMS head.
        backend: ``'ort'`` selects ``ONNX_ORT``, ``'trt'`` selects ``ONNX_TRT``.
        device: target device; CPU when falsy.

    Raises:
        NotImplementedError: for any other ``backend`` value.
    """

    def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, backend='ort', device=None):
        super().__init__()
        if not device:
            device = torch.device('cpu')
        model.model[-1].dynamic = False
        self.model = model.to(device)

        # Resolve the head class from the backend name.
        if backend == 'ort':
            head_cls = ONNX_ORT
        elif backend == 'trt':
            head_cls = ONNX_TRT
        else:
            raise NotImplementedError
        self.patch_model = head_cls

        nms_head = head_cls(max_obj, iou_thres, score_thres, device)
        self.end2end = nms_head
        nms_head.eval()

        # Re-expose the wrapped model's metadata on the wrapper.
        self.stride = self.model.stride
        self.names = self.model.names

    def forward(self, x):
        """Feed the first element of the model's output through the NMS head."""
        raw = self.model(x)[0]
        return self.end2end(raw)

0 comments on commit 99de36c

Please sign in to comment.