# -*- coding: utf-8 -*-
import threading
import warnings
import numpy as np
from .base import *
from .struct import *
from .tensor import *
class TensorBatchCollector:
    """
    A utility that collects many single-batch tensors into one multi-batch
    tensor. Once the multi-batch tensor has been filled, the result tensor can
    be fetched and sent to the model input for forwarding.

    The single-batch tensors being collected do not need to share the device
    type or device id of the result tensor, but their dtype must match and
    their shape must match the result shape except for the highest (batch)
    dimension.

    Args:
        shape: the shape of the multi-batch result tensor; ``shape[0]`` is the
            number of batch slots that can be collected.
        dtype(LiteDataType): the datatype of the single-batch tensors and of
            the result tensor, default value is LiteDataType.LITE_INT8.
        device_type(LiteDeviceType): the device type on which the result
            tensor is allocated, default value is LiteDeviceType.LITE_CUDA.
        device_id: the device id on which the result tensor is allocated,
            default 0.
        is_pinned_host: whether the result tensor uses pinned host memory
            (see CUDA pinned memory), default False.
        tensor(LiteTensor): an already-created multi-batch result tensor. If
            it is not None, its shape must equal ``shape`` and the dtype,
            device type, device id and pinned-host flag are all taken from
            it (the corresponding arguments are ignored). If it is None, the
            result tensor is created by the TensorBatchCollector. Default
            None.

    Note:
        When collecting, the single-batch tensor or array shape must match the
        result tensor shape except for the batch dimension (the highest
        dimension).

    Examples:

        .. code-block:: python

            import numpy as np
            batch_tensor = TensorBatchCollector([4, 8, 8])
            arr = np.ones([8, 8], "int8")
            for i in range(4):
                batch_tensor.collect(arr)
                arr += 1
            data = batch_tensor.to_numpy()
            assert data.shape[0] == 4
            assert data.shape[1] == 8
            assert data.shape[2] == 8
            for i in range(4):
                for j in range(64):
                    assert data[i][j // 8][j % 8] == i + 1
    """

    def __init__(
        self,
        shape,
        dtype=LiteDataType.LITE_INT8,
        device_type=LiteDeviceType.LITE_CUDA,
        device_id=0,
        is_pinned_host=False,
        tensor=None,
    ):
        # Guards _free_list, which every collect/free/get call reads or mutates.
        self._mutex = threading.Lock()
        self.shape = shape
        # Batch ids that have not been filled yet; collect() takes the front.
        self._free_list = list(range(self.shape[0]))
        if tensor is not None:
            ndim = tensor.layout.ndim
            # list() on both sides so a tuple ``shape`` compares correctly
            # against the list produced from the layout's shape array.
            assert (
                list(tensor.layout.shapes)[0:ndim] == list(shape)
            ), "The tensor set to TensorBatchCollector is not right."
            self._tensor = tensor
            # Take every property from the supplied tensor: dev_type/dev_id
            # are what _copy_to_subtensor uses to build staging buffers, so
            # they must describe this tensor, not the constructor defaults.
            self.dtype = tensor.layout.data_type
            self.dev_type = tensor.device_type
            self.dev_id = tensor.device_id
            self.is_pinned_host = tensor.is_pinned_host
        else:
            # Round-trip through LiteLayout to normalise the dtype value into
            # the form stored in a layout, so later comparisons match.
            self.dtype = LiteLayout(dtype=dtype).data_type
            self.dev_type = device_type
            self.dev_id = device_id
            self.is_pinned_host = is_pinned_host
            self._tensor = LiteTensor(
                LiteLayout(shape, dtype), device_type, device_id, is_pinned_host
            )
        # Aliases for the attribute names the constructor historically set
        # when a tensor was supplied; kept so existing readers still work.
        self.device_type = self.dev_type
        self.device_id = self.dev_id

    def _check_batch_id(self, batch_id):
        """Assert that ``batch_id`` addresses an existing batch slot."""
        assert (
            0 <= batch_id < self.shape[0]
        ), "the idx specific the tensor is out of range."

    def _pop_free_id(self):
        """Reserve the front free batch id; warn and return -1 if none is free."""
        with self._mutex:
            if not self._free_list:
                warnings.warn(
                    "all batch has been collected, please call free before collect again."
                )
                return -1
            return self._free_list.pop(0)

    def _copy_to_subtensor(self, subtensor, *data):
        """
        Copy host data into ``subtensor``.

        ``data`` is forwarded unchanged to ``set_data_by_copy`` /
        ``set_data_by_share`` (a numpy array, or a ctypes buffer plus its
        length). A CPU subtensor is written directly; otherwise the data is
        shared into a pinned host tensor first and then copied over with
        ``copy_from``.
        """
        if subtensor.device_type == LiteDeviceType.LITE_CPU:
            subtensor.set_data_by_copy(*data)
        else:
            pinned_tensor = LiteTensor(
                subtensor.layout, self.dev_type, self.dev_id, True
            )
            pinned_tensor.set_data_by_share(*data)
            subtensor.copy_from(pinned_tensor)

    def collect_id(self, array, batch_id):
        """
        Collect a single batch from an array and store its data at the
        specific batch_id.

        Args:
            array: a LiteTensor or numpy ndarray; its shape must match the
                result tensor shape except the highest dimension.
            batch_id: the batch id to store the array data into; it must be in
                ``[0, shape[0])``. If the batch_id has already been collected,
                a warning is issued and the data is overwritten.

        Returns:
            the batch_id the data was stored into.
        """
        self._check_batch_id(batch_id)
        with self._mutex:
            if batch_id in self._free_list:
                self._free_list.remove(batch_id)
            else:
                warnings.warn(
                    "batch {} has been collected, please call free before collected it again.".format(
                        batch_id
                    )
                )
        return self._collect_with_id(array, batch_id)

    def _collect_with_id(self, array, batch_id):
        """Validate ``array`` and copy it into slot ``batch_id``; return the id."""
        expected_shape = list(self.shape[1:])
        if isinstance(array, np.ndarray):
            assert list(array.shape) == expected_shape
            in_dtype = ctype_to_lite_dtypes[np.ctypeslib.as_ctypes_type(array.dtype)]
            assert in_dtype == self.dtype
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            self._copy_to_subtensor(subtensor, array)
        else:
            assert isinstance(array, LiteTensor)
            ndim = array.layout.ndim
            assert list(array.layout.shapes)[0:ndim] == expected_shape
            assert array.layout.data_type == self.dtype
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            # The class contract allows the input tensor to live on another
            # device; it is copied directly with copy_from, without staging.
            subtensor.copy_from(array)
        return batch_id

    def collect(self, array):
        """
        Collect a single batch from an array and store its data into an empty
        batch slot, namely the front batch id of the free list.

        Args:
            array: a LiteTensor or numpy ndarray; its shape must match the
                result tensor shape except the highest dimension.

        Returns:
            the batch id the data was stored into, or -1 (with a warning) when
            every batch slot has already been collected.
        """
        idx = self._pop_free_id()
        if idx < 0:
            return -1
        return self._collect_with_id(array, idx)

    def collect_by_ctypes(self, data, length):
        """
        Collect a single batch from a ctypes memory buffer and store it into
        an empty batch slot, namely the front batch id of the free list.

        Args:
            data: the ctypes memory buffer holding one batch of data.
            length: the length of ``data``, forwarded to the tensor's
                ``set_data_by_copy`` / ``set_data_by_share``.

        Returns:
            the batch id the data was stored into, or -1 (with a warning) when
            every batch slot has already been collected.
        """
        idx = self._pop_free_id()
        if idx < 0:
            return -1
        subtensor = self._tensor.slice([idx], [idx + 1])
        self._copy_to_subtensor(subtensor, data, length)
        return idx

    def free(self, indexes):
        """
        Free the batch ids in ``indexes``; a freed batch id can be collected
        again without a warning.

        Args:
            indexes: an iterable of batch ids to free, each in
                ``[0, shape[0])``. Freeing an id that is already free issues
                a warning; the id still ends up in the free list exactly once.
        """
        # Materialise once so generators work and every id is validated
        # before the free list is touched.
        indexes = list(indexes)
        for i in indexes:
            self._check_batch_id(i)
        with self._mutex:
            for i in indexes:
                if i in self._free_list:
                    warnings.warn(
                        "batch id {} has not collected before free it.".format(i)
                    )
                    # Remove first so the id is re-queued at the back, once.
                    self._free_list.remove(i)
                self._free_list.append(i)

    def get_tensor_at(self, idx):
        """
        Get the tensor at batch ``idx`` of the internal big tensor; warns if
        that batch has not been collected.

        Args:
            idx: the batch index in the internal big tensor, in
                ``[0, shape[0])``.

        Returns:
            a slice of the result tensor covering batch ``idx``.
        """
        self._check_batch_id(idx)
        with self._mutex:
            is_free = idx in self._free_list
        if is_free:
            warnings.warn(
                "tensor with batch id {} has not collected before get it.".format(idx)
            )
        return self._tensor.slice([idx], [idx + 1])

    def get(self):
        """
        After collection has finished, get the result tensor.
        """
        return self._tensor

    def to_numpy(self):
        """
        Convert the result tensor to a numpy ndarray.
        """
        return self._tensor.to_numpy()