1、下载安装 PyTorch(collections 是 Python 标准库,无需另外安装);validation.py 另需 numpy、pycuda 和 TensorRT。

validation.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import logging


# module logger; when run as a script, main() sets its format/level via logging.basicConfig
logger = logging.getLogger(name=__name__)


def flatten_dict(obj,
                 out=None):
    """
    Collect the non-dict leaves of a (possibly nested) dict into one flat dict.

    Nested dicts are walked depth-first in insertion order; only leaf
    ``key: value`` pairs are kept, so leaf keys must be unique across levels.

    :param obj: dict to flatten
    :param out: optional mapping to fill; when None a new, empty instance of
        ``type(obj)`` is created (an OrderedDict input yields an OrderedDict)
    :returns: the filled ``out`` mapping
    :raises AssertionError: if ``obj`` is not a dict or a leaf key repeats
    """
    assert isinstance(obj, dict), 'dict type required'

    flat = type(obj)() if out is None else out
    # explicit stack of item iterators replaces recursion; each iterator keeps
    # its position, so resuming the parent continues right after the sub-dict
    pending = [iter(obj.items())]
    while pending:
        for key, value in pending[-1]:
            if isinstance(value, dict):
                pending.append(iter(value.items()))
                break
            assert key not in flat, 'key conflicted'
            flat[key] = value
        else:
            # current level exhausted: go back up one level
            pending.pop()
    return flat


def ensure_list(obj):
    """Return a new list of ``obj``'s items if it is a list/tuple/set, else ``[obj]``."""
    return list(obj) if isinstance(obj, (list, tuple, set)) else [obj]


def _load_golden_data(golden_data_filename):
    """
    Load golden I/O data and return ``(input_data, output_data)`` as flat dicts.

    ``.npz`` files are expected to hold ``inputs`` and ``outputs`` entries; any
    other file is loaded with ``np.load`` as one object holding ``inputs`` and
    ``outputs`` keys. Nested dicts are flattened with ``flatten_dict``.
    """
    import numpy as np

    # NOTE(security): the golden dicts are stored as object (pickled) arrays, so
    # allow_pickle=True is required; only load golden data from trusted files.
    if golden_data_filename.endswith('.npz'):
        # NpzFile keeps the archive open; the context manager closes it
        with np.load(golden_data_filename, encoding='bytes', allow_pickle=True) as test_data:
            input_data = test_data['inputs'].tolist()
            output_data = test_data['outputs'].tolist()
    else:
        # allow_pickle is needed here too: numpy >= 1.16.3 refuses to load
        # object arrays without it
        test_data = np.load(golden_data_filename, encoding='bytes', allow_pickle=True).tolist()
        input_data = test_data['inputs']
        output_data = test_data['outputs']
    return flatten_dict(input_data), flatten_dict(output_data)


def _engine_bindings(engine, named_buffers):
    """
    Return device pointers ordered as the engine's binding slots.

    Golden data is keyed by tensor name, while ``execute_async`` takes pointers
    by binding index. When the names map one-to-one onto all engine bindings,
    each pointer is placed at ``engine.get_binding_index(name)``; otherwise
    (e.g. golden data keyed by position) the given order is used.

    :param engine: built TensorRT engine
    :param named_buffers: ``(name, device_allocation)`` pairs, inputs first
    :returns: list of ``int`` device pointers
    """
    pointers = [int(buffer) for _, buffer in named_buffers]
    indices = [engine.get_binding_index(str(name)) for name, _ in named_buffers]
    if sorted(indices) != list(range(engine.num_bindings)):
        logger.warning('golden data names do not match engine bindings, '
                       'binding in golden data order')
        return pointers
    bindings = [0] * len(pointers)
    for index, pointer in zip(indices, pointers):
        bindings[index] = pointer
    return bindings


def validate(onnx_model_filename, golden_data_filename,
             atol=1e-3, rtol=1e-3,
             **kwargs):
    """
    inference model in tensorrt, validate with given golden data

    Builds a TensorRT engine (``max_batch_size = 1``) from the ONNX model, runs
    it once on the golden inputs and compares every output with
    ``np.testing.assert_allclose``.

    :param onnx_model_filename: path to the ``.onnx`` model
    :param golden_data_filename: path to ``.npz`` / ``.npy`` golden I/O data
    :param atol: absolute tolerance for the comparison
    :param rtol: relative tolerance for the comparison
    :param kwargs: accepted for interface compatibility; unused
    :returns: True if all outputs match within tolerance, else False
    :raises RuntimeError: if the ONNX model cannot be parsed or the engine
        cannot be built
    """

    import numpy as np
    import pycuda.autoinit  # noqa: F401  imported only for its CUDA context setup side effect
    import pycuda.driver as cuda
    import tensorrt as trt

    trt_logger = trt.Logger()
    builder = trt.Builder(trt_logger)
    network = builder.create_network()
    parser = trt.OnnxParser(network, trt_logger)
    with open(onnx_model_filename, 'rb') as model_file:
        onnx_model_proto_str = model_file.read()
    # parse() reports failure via its return value; surface the parser errors
    # instead of continuing with an incomplete network
    if not parser.parse(onnx_model_proto_str):
        for idx in range(parser.num_errors):
            logger.error('onnx parser: %s', parser.get_error(idx))
        raise RuntimeError('failed to parse ONNX model {}'.format(onnx_model_filename))
    logger.info('model parsing passed')

    builder.max_batch_size = 1
    engine = builder.build_cuda_engine(network)
    if engine is None:  # build_cuda_engine signals failure by returning None
        raise RuntimeError('failed to build TensorRT engine')
    logger.info('build engine passed')

    logger.info('using golden data %s', golden_data_filename)
    input_data, output_data = _load_golden_data(golden_data_filename)
    logger.info('golden data contain %d inputs or outputs',
                len(input_data) + len(output_data))

    # to_device already copies the host inputs to the device, so no extra
    # host-to-device copy on the stream is needed before launching
    input_device_data = {name: cuda.to_device(value) for name, value in input_data.items()}
    output_device_data = {name: cuda.mem_alloc(value.nbytes) for name, value in output_data.items()}
    output_host_data = {name: cuda.pagelocked_empty_like(value) for name, value in output_data.items()}
    logger.info('data transfered to device')

    bindings = _engine_bindings(engine,
                                list(input_device_data.items()) + list(output_device_data.items()))
    with engine.create_execution_context() as context:
        stream = cuda.Stream()
        context.execute_async(bindings=bindings, stream_handle=stream.handle)
        for name, device_buffer in output_device_data.items():
            cuda.memcpy_dtoh_async(output_host_data[name], device_buffer, stream=stream)
        # wait for the launch and the device-to-host copies to finish
        stream.synchronize()

    logger.info('execution passed')

    # validate
    passed = True
    for name, gt in output_data.items():
        pr = output_host_data[name]
        logger.info('testing on output %s ...', name)
        try:
            np.testing.assert_allclose(pr, gt,
                                       rtol=rtol, atol=atol,
                                       equal_nan=False, verbose=True)
        except AssertionError as e:
            passed = False
            logger.error('failed: %s\n', e)
    if passed:
        logger.info('accuracy passed')
    else:
        logger.info('accuracy not passed')

    return passed


def main(argv=None):
    """
    Command line entry: validate an ONNX model with TensorRT against golden data.

    :param argv: argument strings to parse; None lets argparse use
        ``sys.argv[1:]``
    Exits with status 2 (argparse usage error) when ``--test_data`` is missing,
    and with status 1 when the accuracy check fails.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(description='tensorrt_validate',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     )
    parser.add_argument('model', nargs=1,
                        help='path to model.onnx',
                        )
    parser.add_argument('--debug', '-d', action='store_true',
                        help='enable debug logging and checking',
                        )
    parser.add_argument('--test_data', '-t', type=str, default='',
                        help='I/O golden data for validation, e.g. test.npy, test.npz',
                        )
    parser.add_argument('--atol', '-p', type=float, default=1e-4,
                        help='assertion absolute tolerance for validation',
                        )
    parser.add_argument('--rtol', type=float, default=1e-3,
                        help='assertion relative tolerance for validation',
                        )
    args = parser.parse_args(argv)
    if not args.test_data:
        # validate() needs golden data; report a usage error now instead of
        # failing to load '' only after the TensorRT engine has been built
        parser.error('--test_data is required for validation')

    logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(format=logging_format, level=logging_level)

    onnx_model_filename = args.model[0]
    golden_data_filename = args.test_data
    atol, rtol = args.atol, args.rtol

    passed = validate(onnx_model_filename, golden_data_filename,
                      atol=atol, rtol=rtol)
    if not passed:
        # non-zero exit status so scripts/CI can detect an accuracy failure
        sys.exit(1)


# run the command line entry point only when executed as a script, not on import
if __name__ == '__main__':
    main()












torch_export_helper.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import numpy as np
import torch

from collections import OrderedDict as Dict


# public API of this helper module (the names exported by ``from ... import *``)
__all__ = [
    'export_data',
    'export_onnx_with_validation',
]


def ensure_list(obj):
    """Wrap a scalar in a list; copy a list/tuple/set into a new list."""
    if not isinstance(obj, (list, tuple, set)):
        return [obj]
    return list(obj)


def ensure_tuple(obj):
    """Return ``obj``'s items as a tuple if it is a tuple/list/set, else ``(obj,)``."""
    return tuple(obj) if isinstance(obj, (tuple, list, set)) else (obj,)


def flatten_list(obj,
                 out=None):
    """
    Flatten arbitrarily nested lists/tuples into one flat list.

    Tuples are treated like lists, matching how ``zip_dict`` in
    ``export_onnx_with_validation`` pairs nested names with values, so name
    groups given as tuples are flattened as well.

    :param obj: list or tuple, possibly containing nested lists/tuples
    :param out: optional list to append to; when None a new one is created
        (same type as ``obj`` if it is a list, otherwise a plain list)
    :returns: ``out`` with all leaf items appended in depth-first order
    :raises AssertionError: if ``obj`` is neither a list nor a tuple
    """
    assert isinstance(obj, (list, tuple)), 'list type required'

    if out is None:
        # tuples are immutable, so their items are collected into a plain list
        out = type(obj)() if isinstance(obj, list) else []
    for item in obj:
        if isinstance(item, (list, tuple)):
            flatten_list(item, out)
        else:
            out.append(item)
    return out


def export_data(state_dict,
                prefix=''):
    """
    export binary data with meta text for raw C++ inference engines

    For every ``key, value`` in ``state_dict``:

    - torch tensors and numpy arrays are dumped raw with ``ndarray.tofile``
      (no header) to ``<prefix>_<key>.bin`` (``<key>.bin`` when prefix is
      empty), and ``key.dtype=...`` / ``key.shape=...`` lines are written to
      the meta file;
    - any other value is written to the meta file as ``key=value``.

    The meta file is ``<prefix>.txt``, or ``meta.txt`` when prefix is empty.
    Files are written relative to the current working directory.
    """

    def str_(obj):
        # sequences are written comma-separated without brackets or spaces,
        # e.g. (2, 3) -> "2,3"; a 1-tuple keeps its trailing comma: (3,) -> "3,"
        if isinstance(obj, set):
            # str(set()) is 'set()', which slicing would turn into 'et(';
            # a list renders the same elements in the same order
            obj = list(obj)
        if isinstance(obj, (tuple, list)):
            return str(obj)[1:-1].replace(' ', '')
        return str(obj)

    prefix_ = prefix + ('_' if prefix else '')
    # context manager so the meta file is closed even if a dump fails
    with open('{}.txt'.format(prefix or 'meta'), 'w') as fp:
        for key, value in state_dict.items():
            data = None
            if torch.is_tensor(value):
                data = value.data.cpu().numpy()
            elif isinstance(value, np.ndarray):
                data = value
            if data is None:
                fp.write('{}={}\n'.format(key, str_(value)))
                continue
            data.tofile('{}{}.bin'.format(prefix_, key))
            fp.write('{}.dtype={}\n'.format(key, str_(data.dtype.name)))
            fp.write('{}.shape={}\n'.format(key, str_(data.shape)))


def export_onnx_with_validation(model, inputs, export_basepath,
                                input_names=None, output_names=None,
                                use_npz=True,
                                *args, **kwargs):
    """
    export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file

    Writes ``<export_basepath>.onnx`` and the golden data file, either
    ``<export_basepath>.npz`` (``use_npz=True``) or ``<export_basepath>.npy``.
    The golden data holds two dicts, ``inputs`` and ``outputs``, keyed by the
    given names (or by position when names are None); nested name groups
    become sub-dicts keyed by their position as a string.

    :param model: model passed to ``torch.onnx.export`` and, if needed, called
        to obtain sample outputs
    :param inputs: a tensor or a tuple/list of (nested) tensors
    :param export_basepath: output path without extension
    :param input_names: optional, possibly nested, input names; flattened
        before being passed to ``torch.onnx.export``
    :param output_names: optional, possibly nested, output names; flattened
        the same way
    :param use_npz: save golden data as ``.npz`` instead of ``.npy``
    :param args: extra positional arguments for ``torch.onnx.export``,
        placed after its output path argument
    :param kwargs: extra keyword arguments for ``torch.onnx.export``
    :returns: the model outputs as a tuple
    """

    def is_tuple_or_list(x):
        return isinstance(x, (tuple, list))

    def tensors_to_arrays(tensors):
        # recursively convert (nested) tensors to numpy arrays / lists of arrays
        if torch.is_tensor(tensors):
            return tensors.data.cpu().numpy()
        return [tensors_to_arrays(tensor) for tensor in tensors]

    def zip_dict(keys, values):
        # pair (nested) names with (nested) values; note zip() stops at the
        # shorter of the two sequences
        if keys is None:  # no names given: key by position
            keys = range(len(values))
        ret = Dict()
        for idx, (key, value) in enumerate(zip(keys, values)):
            is_key_list = is_tuple_or_list(key)
            is_value_list = is_tuple_or_list(value)
            assert is_key_list == is_value_list, 'keys and values mismatch'
            if is_value_list:
                ret[str(idx)] = zip_dict(key, value)
            else:
                ret[key] = value
        return ret

    def run_model(training):
        # compute sample outputs with the model temporarily in ``training`` mode
        set_training = getattr(torch.onnx, 'set_training', None)
        if set_training is not None:  # provided by older PyTorch releases
            with set_training(model, training):
                return model(*torch_inputs)
        if not isinstance(training, bool) or not isinstance(model, torch.nn.Module):
            # non-bool modes (presumably a newer torch.onnx TrainingMode value)
            # or non-Module callables: keep the model's current mode
            return model(*torch_inputs)
        was_training = model.training
        model.train(training)
        try:
            return model(*torch_inputs)
        finally:
            model.train(was_training)

    torch_inputs = ensure_tuple(inputs)  # WORKAROUND: for torch.onnx
    outputs = torch.onnx.export(model, torch_inputs, export_basepath + '.onnx',
                                *args,
                                input_names=None if input_names is None else flatten_list(input_names),
                                output_names=None if output_names is None else flatten_list(output_names),
                                **kwargs)
    if outputs is None:  # WORKAROUND: for torch.onnx, export may not return outputs
        # call with the tuple-normalized inputs: ``model(*inputs)`` would unpack
        # a single tensor along its first dimension
        outputs = run_model(kwargs.get('training', False))
    torch_outputs = ensure_tuple(outputs)

    inputs = zip_dict(input_names, tensors_to_arrays(torch_inputs))
    outputs = zip_dict(output_names, tensors_to_arrays(torch_outputs))
    if use_npz:
        np.savez(export_basepath + '.npz', inputs=inputs, outputs=outputs)
    else:
        np.save(export_basepath + '.npy',
                np.array(Dict(inputs=inputs, outputs=outputs)))

    return torch_outputs

 

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐