PyTorch to ONNX, ONNX to TensorRT

1. Install PyTorch and the `collections` package.

validation.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import logging
logger = logging.getLogger(__name__)
def flatten_dict(obj, out=None):
    """Collapse a nested dict into a single-level dict of leaf values.

    Leaf entries keep their original keys; the same key appearing at two
    different nesting levels trips an assertion.  The result container is
    the same mapping type as ``obj``.
    """
    assert isinstance(obj, dict), 'dict type required'
    result = type(obj)() if out is None else out
    for key, value in obj.items():
        if not isinstance(value, dict):
            assert key not in result, 'key conflicted'
            result[key] = value
        else:
            flatten_dict(value, result)
    return result
def ensure_list(obj):
    """Return ``obj`` as a list, wrapping scalars in a one-element list."""
    return list(obj) if isinstance(obj, (list, tuple, set)) else [obj]
def validate(onnx_model_filename, golden_data_filename,
             atol=1e-3, rtol=1e-3,
             **kwargs):
    """
    inference model in tensorrt, validate with given golden data

    Args:
        onnx_model_filename: path to the ONNX model to load.
        golden_data_filename: ``.npz`` or pickled ``.npy`` file holding
            ``inputs`` and ``outputs`` dicts of named numpy arrays
            (possibly nested; they are flattened before use).
        atol: absolute tolerance for the output comparison.
        rtol: relative tolerance for the output comparison.

    Returns:
        bool: True when every output matches the golden data within
        tolerance (and parsing succeeded), False otherwise.
    """
    import numpy as np
    import pycuda.autoinit  # noqa: F401 -- side effect: creates the CUDA context
    import pycuda.driver as cuda
    import tensorrt as trt

    trt_logger = trt.Logger()
    builder = trt.Builder(trt_logger)
    network = builder.create_network()
    parser = trt.OnnxParser(network, trt_logger)
    # FIX: close the model file deterministically instead of leaking the handle
    with open(onnx_model_filename, 'rb') as model_fp:
        onnx_model_proto_str = model_fp.read()
    # FIX: parse() returns False on failure; previously errors were silently
    # ignored and an engine was built from an incomplete network
    if not parser.parse(onnx_model_proto_str):
        for idx in range(parser.num_errors):
            logger.error('parser error: %s', parser.get_error(idx))
        return False
    logger.info('model parsing passed')

    builder.max_batch_size = 1
    engine = builder.build_cuda_engine(network)
    logger.info('build engine passed')

    logger.info('using golden data %s', golden_data_filename)
    if golden_data_filename.endswith('.npz'):
        test_data = np.load(golden_data_filename, encoding='bytes', allow_pickle=True)
        input_data = test_data['inputs'].tolist()
        output_data = test_data['outputs'].tolist()
    else:
        # FIX: allow_pickle=True is required on numpy >= 1.16.3 to load
        # pickled object arrays (the dict payload saved by np.save)
        test_data = np.load(golden_data_filename, encoding='bytes',
                            allow_pickle=True).tolist()
        input_data = test_data['inputs']
        output_data = test_data['outputs']
    input_data = flatten_dict(input_data)
    output_data = flatten_dict(output_data)
    input_names = input_data.keys()
    output_names = output_data.keys()
    logger.info('golden data contain %d inputs or outputs',
                len(input_data) + len(output_data))

    # stage buffers: inputs copied to device, outputs get device + pinned host storage
    input_device_data = {name: cuda.to_device(value)
                         for name, value in input_data.items()}
    output_device_data = {name: cuda.mem_alloc(value.nbytes)
                          for name, value in output_data.items()}
    output_host_data = {name: cuda.pagelocked_empty_like(value)
                        for name, value in output_data.items()}
    logger.info('data transfered to device')

    with engine.create_execution_context() as context:
        stream = cuda.Stream()
        # bindings are raw device pointers, inputs first then outputs
        # (assumes golden-data ordering matches the engine's binding order
        #  -- TODO confirm for multi-I/O models)
        device_data = list(input_device_data.values()) + list(output_device_data.values())
        context.execute_async(bindings=[int(data) for data in device_data],
                              stream_handle=stream.handle)
        for name in output_names:
            cuda.memcpy_dtoh_async(output_host_data[name], output_device_data[name],
                                   stream=stream)
        stream.synchronize()
        logger.info('execution passed')

    # validate each output against the golden values
    passed = True
    for name in output_names:
        pr = output_host_data[name]
        gt = output_data[name]
        logger.info('testing on output {} ...'.format(name))
        try:
            np.testing.assert_allclose(pr, gt,
                                       rtol=rtol, atol=atol,
                                       equal_nan=False, verbose=True)
        except AssertionError as e:
            passed = False
            logger.error('failed: %s\n', e)
    if passed:
        logger.info('accuracy passed')
    else:
        logger.info('accuracy not passed')
    return passed
def main():
    """CLI entry point: parse arguments, set up logging, run validation."""
    import argparse

    arg_parser = argparse.ArgumentParser(
        description='tensorrt_validate',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    arg_parser.add_argument('model', nargs=1,
                            help='path to model.onnx')
    arg_parser.add_argument('--debug', '-d', action='store_true',
                            help='enable debug logging and checking')
    arg_parser.add_argument('--test_data', '-t', type=str, default='',
                            help='I/O golden data for validation, e.g. test.npy, test.npz')
    arg_parser.add_argument('--atol', '-p', type=float, default=1e-4,
                            help='assertion absolute tolerance for validation')
    arg_parser.add_argument('--rtol', type=float, default=1e-3,
                            help='assertion relative tolerance for validation')
    args = arg_parser.parse_args()

    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
        level=level)

    validate(args.model[0], args.test_data,
             atol=args.atol, rtol=args.rtol)


if __name__ == '__main__':
    main()
torch_export_helper.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import torch
from collections import OrderedDict as Dict
# Public API of this helper module.
__all__ = [
    'export_data',
    'export_onnx_with_validation',
]
def ensure_list(obj):
    """Coerce ``obj`` to a list; a non-collection becomes a singleton list."""
    if not isinstance(obj, (list, tuple, set)):
        return [obj]
    return list(obj)
def ensure_tuple(obj):
    """Coerce ``obj`` to a tuple; a non-collection becomes a singleton tuple."""
    return tuple(obj) if isinstance(obj, (tuple, list, set)) else (obj, )
def flatten_list(obj, out=None):
    """Recursively flatten arbitrarily nested lists into one flat list."""
    assert isinstance(obj, list), 'list type required'
    acc = type(obj)() if out is None else out
    for element in obj:
        if not isinstance(element, list):
            acc.append(element)
        else:
            flatten_list(element, acc)
    return acc
def export_data(state_dict,
                prefix=''):
    """
    export binary data with meta text for raw C++ inference engines

    Each tensor/ndarray value is dumped to ``<prefix>_<key>.bin`` and its
    dtype/shape recorded in ``<prefix>.txt`` (``meta.txt`` when *prefix*
    is empty); every other value is written to the text file directly.
    """
    def str_(obj):
        # render collections as a bare comma-separated list, e.g. (2, 3) -> '2,3'
        if isinstance(obj, (tuple, list, set)):
            return str(obj)[1:-1].replace(' ', '')
        return str(obj)

    prefix_ = prefix + ('_' if prefix else '')
    # FIX: context manager guarantees the meta file is closed even if a
    # tofile()/write() call raises mid-loop (original used open()/close())
    with open('{}.txt'.format(prefix or 'meta'), 'w') as fp:
        for key, value in state_dict.items():
            data = None
            if torch.is_tensor(value):
                data = value.data.cpu().numpy()
            elif isinstance(value, np.ndarray):
                data = value
            if data is not None:
                data.tofile('{}{}.bin'.format(prefix_, key))
                fp.write('{}.dtype={}\n'.format(key, str_(data.dtype.name)))
                fp.write('{}.shape={}\n'.format(key, str_(data.shape)))
            else:
                fp.write('{}={}\n'.format(key, str_(value)))
def export_onnx_with_validation(model, inputs, export_basepath,
                                input_names=None, output_names=None,
                                use_npz=True,
                                *args, **kwargs):
    """
    export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file

    The ONNX graph is written to ``export_basepath + '.onnx'``; the sample
    I/O dicts are saved to ``.npz`` (default) or a pickled ``.npy``.
    Extra ``*args``/``**kwargs`` are forwarded to ``torch.onnx.export``.
    Returns the model outputs as a tuple of tensors.
    """
    is_tuple_or_list = lambda x: isinstance(x, (tuple, list))

    def tensors_to_arrays(tensors):
        # recursively convert a (possibly nested) tensor structure into
        # numpy arrays, preserving the nesting as lists
        if torch.is_tensor(tensors):
            return tensors.data.cpu().numpy()
        arrays = []
        for tensor in tensors:
            arrays.append(tensors_to_arrays(tensor))
        return arrays

    def zip_dict(keys, values):
        # pair names with values, recursing into matching nested sequences;
        # falls back to positional indices when keys is None
        if keys is None:  # can be None
            keys = range(len(values))
        ret = Dict()
        for idx, (key, value) in enumerate(zip(keys, values)):
            is_key_list = is_tuple_or_list(key)
            is_value_list = is_tuple_or_list(value)
            assert is_key_list == is_value_list, 'keys and values mismatch'
            if is_value_list:
                # nested group is stored under its positional index
                ret[str(idx)] = zip_dict(key, value)
            else:
                ret[key] = value
        return ret

    torch_inputs = ensure_tuple(inputs)  # WORKAROUND: for torch.onnx
    outputs = torch.onnx.export(model, torch_inputs, export_basepath + '.onnx',
                                input_names=None if input_names is None else flatten_list(input_names),
                                output_names=None if output_names is None else flatten_list(output_names),
                                *args, **kwargs)
    if outputs is None:  # WORKAROUND: for torch.onnx
        # older torch.onnx.export returned None, so rerun the model to
        # capture sample outputs
        # NOTE(review): torch.onnx.set_training was removed in newer
        # PyTorch releases -- confirm the targeted torch version
        training = kwargs.get('training', False)
        with torch.onnx.set_training(model, training):
            outputs = model(*inputs)
    torch_outputs = ensure_tuple(outputs)

    inputs = zip_dict(input_names, tensors_to_arrays(torch_inputs))
    outputs = zip_dict(output_names, tensors_to_arrays(torch_outputs))
    if use_npz:
        np.savez(export_basepath + '.npz', inputs=inputs, outputs=outputs)
    else:
        # wrap the dict in a 0-d object array so np.save accepts it
        np.save(export_basepath + '.npy',
                np.array(Dict(inputs=inputs, outputs=outputs)))
    return torch_outputs
More recommendations
All comments (0)