Class EnableNchwxxPass

Inheritance Relationships

Base Type

public mgb::gopt::TensorReformatPass

Derived Type

mgb::gopt::EnableNchw44DotPass

Class Documentation

class mgb::gopt::EnableNchwxxPass : public mgb::gopt::TensorReformatPass

Converts the tensor format from NCHW to NCHWxx (channels packed in groups of x) to speed up inference on certain devices.

Subclassed by mgb::gopt::EnableNchw44DotPass
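As a rough illustration of what the NCHWxx layout means, the sketch below repacks a dense, row-major NCHW float buffer into a channel-blocked buffer of shape (N, ceil(C/x), H, W, x), where x is the channel pack size. The function name nchw_to_nchwxx, the plain float buffers, and the zero-filling of a trailing partial channel block are assumptions made for illustration; this is not MegEngine's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: repack an NCHW tensor into NCHWxx, i.e. the shape
// (N, ceil(C / pack), H, W, pack). Lanes beyond C in the last channel block
// are zero-filled (an assumption; a real pass may reject such shapes instead).
std::vector<float> nchw_to_nchwxx(const std::vector<float>& src, std::size_t n,
                                  std::size_t c, std::size_t h, std::size_t w,
                                  std::size_t pack) {
    assert(pack > 0);
    assert(src.size() == n * c * h * w);
    const std::size_t cb = (c + pack - 1) / pack;  // number of channel blocks
    std::vector<float> dst(n * cb * h * w * pack, 0.0f);
    for (std::size_t in = 0; in < n; ++in)
        for (std::size_t ic = 0; ic < c; ++ic)
            for (std::size_t ih = 0; ih < h; ++ih)
                for (std::size_t iw = 0; iw < w; ++iw) {
                    // Source offset in row-major NCHW.
                    const std::size_t s = ((in * c + ic) * h + ih) * w + iw;
                    // Channel ic lands in block ic / pack, lane ic % pack.
                    const std::size_t d =
                        (((in * cb + ic / pack) * h + ih) * w + iw) * pack +
                        ic % pack;
                    dst[d] = src[s];
                }
    return dst;
}
```

For example, with N = 1, C = 2, H = 1, W = 2 and pack = 4, the input {0, 1, 2, 3} becomes {0, 2, 0, 0, 1, 3, 0, 0}: each spatial position now holds one contiguous 4-wide vector of channels, which is what lets vectorized kernels load all packed channels of a pixel at once.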

Public Types

enum TransType

Flag that determines how a convolution is transformed to NCHWxx.

Values:

enumerator TRANS_PURE_NCHWXX

Both the weight and the source tensor are transformed to NCHWxx.

enumerator TRANS_HYBIRD_NCHWXX

The input remains in NCHW and the output is produced in NCHWxx.

enumerator TRANS_NONE

No transformation is needed.
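The three enumerators can be read as the outcome of a per-convolution decision. The sketch below shows one plausible, simplified rule that looks only at channel counts: pure conversion when both input and output channels are multiples of the pack size, hybrid conversion when the input has fewer channels than one pack (for example a 3-channel image input) but the output channels divide evenly, and no conversion otherwise. The function choose_trans_type and this rule are assumptions for illustration; the actual pass may also check kernel size, stride, data type, or other conditions.

```cpp
#include <cassert>
#include <cstddef>

// Standalone copy of the documented enumerators, for illustration only.
enum class TransType { TRANS_PURE_NCHWXX, TRANS_HYBIRD_NCHWXX, TRANS_NONE };

// Hypothetical, simplified decision rule based only on channel counts.
TransType choose_trans_type(std::size_t ic, std::size_t oc, std::size_t pack) {
    assert(pack > 0);
    const bool oc_ok = oc % pack == 0;
    if (ic % pack == 0 && oc_ok)
        return TransType::TRANS_PURE_NCHWXX;   // src and weight both packed
    if (ic < pack && oc_ok)
        return TransType::TRANS_HYBIRD_NCHWXX; // NCHW in, NCHWxx out
    return TransType::TRANS_NONE;              // leave the conv in NCHW
}
```

With pack = 4, a conv with 3 input and 16 output channels would be hybrid under this rule, one with 8 input and 16 output channels would be pure, and one with 6 input channels would be left unchanged.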

Public Functions

EnableNchwxxPass(size_t pack_c_size)
const char *name() const override
void set_name(std::string in_name)
void fill_opr_convert_fun(size_t pack_c_size)

Public Static Functions

std::unique_ptr<EnableNchwxxPass> make_nchwxx_converter(size_t pack_c_size)

Creates an NCHW -> NCHWxx converter optimization pass; pack_c_size is the channel pack size x, e.g. 4, 8, or 16.
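To make the role of pack_c_size concrete, the sketch below computes the 5-D shape an (N, C, H, W) activation would take after conversion for a given pack size x. The helper name nchwxx_shape and the rounding-up of C (padding the last channel block) are assumptions; the real converter may instead require C to be a multiple of x.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical helper: NCHW shape -> NCHWxx shape (N, ceil(C / x), H, W, x).
std::array<std::size_t, 5> nchwxx_shape(const std::array<std::size_t, 4>& nchw,
                                        std::size_t pack_c_size) {
    assert(pack_c_size > 0);
    // Number of channel blocks, rounding up when C is not divisible by x.
    const std::size_t blocks = (nchw[1] + pack_c_size - 1) / pack_c_size;
    return {nchw[0], blocks, nchw[2], nchw[3], pack_c_size};
}
```

For a (1, 32, 56, 56) tensor, pack sizes 4, 8 and 16 give (1, 8, 56, 56, 4), (1, 4, 56, 56, 8) and (1, 2, 56, 56, 16); the element count is unchanged whenever C is a multiple of x.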