Class TensorReformatPass

Inheritance Relationships

Base Type

  • public Pass

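Derived Types

  • public mgb::gopt::EnableCHWN4Pass
  • public mgb::gopt::EnableNCHW4Pass
  • public mgb::gopt::EnableNchwxxPass
  • public mgb::gopt::EnableTensorCorePass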

Class Documentation

class mgb::gopt::TensorReformatPass : public Pass

Converts tensor formats to accelerate inference on NVIDIA platforms.

Subclassed by mgb::gopt::EnableCHWN4Pass, mgb::gopt::EnableNCHW4Pass, mgb::gopt::EnableNchwxxPass, mgb::gopt::EnableTensorCorePass

Public Functions

TensorReformatPass &set_var_replace_check_flag(VarReplaceCheckFlag flag)
void apply(OptState &opt) const override

Protected Attributes

ThinHashMap<Typeinfo*, thin_function<OperatorNodeBase*(OperatorNodeBase*, const VarNodeArray&)>> m_opr_replace_func
VarReplaceCheckFlag m_var_replace_check_flag = VarReplaceCheckFlag::CHECK_ALL
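
The members above suggest a type-keyed dispatch table: each operator type maps to a function that builds a replacement operator from the original operator and its new inputs. The following is a minimal, self-contained sketch of that pattern using only the standard library. Every name in it (Node, ConvNode, PoolNode, CheckFlag, ReformatSketch, add_rule, reformat_demo) is a hypothetical stand-in and not part of MegEngine; std::unordered_map and std::function are assumed here in place of ThinHashMap and thin_function, and how the real pass traverses the graph is not shown in this reference.

```cpp
#include <functional>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-ins for operator nodes; these are NOT MegEngine types.
struct Node {
    virtual ~Node() = default;
    std::string layout = "NCHW";
    std::vector<Node*> inputs;
};
struct ConvNode : Node {};
struct PoolNode : Node {};

// Hypothetical flag enum; only CHECK_ALL is named in the reference above.
enum class CheckFlag { NOCHECK, CHECK_ALL };

class ReformatSketch {
public:
    // Same shape as the documented map value: (operator, new inputs) -> operator.
    using ReplaceFunc = std::function<Node*(Node*, const std::vector<Node*>&)>;

    // Returns *this so calls can be chained, like the documented setter.
    ReformatSketch& set_check_flag(CheckFlag flag) {
        m_flag = flag;
        return *this;
    }

    // Registers a replacement rule for one concrete node type (assumed helper).
    ReformatSketch& add_rule(std::type_index type, ReplaceFunc func) {
        m_replace_func[type] = std::move(func);
        return *this;
    }

    // Visits each node and rewrites those whose dynamic type has a rule.
    void apply(std::vector<Node*>& graph) const {
        for (Node*& node : graph) {
            auto it = m_replace_func.find(std::type_index(typeid(*node)));
            if (it != m_replace_func.end()) {
                node = it->second(node, node->inputs);
            }
        }
    }

    CheckFlag check_flag() const { return m_flag; }

private:
    std::unordered_map<std::type_index, ReplaceFunc> m_replace_func;
    CheckFlag m_flag = CheckFlag::CHECK_ALL;
};

// Builds a two-node graph, registers a rule for ConvNode only, and reports
// the resulting layouts as "conv,pool".
std::string reformat_demo() {
    ConvNode conv;
    PoolNode pool;
    std::vector<Node*> graph = {&conv, &pool};

    ReformatSketch pass;
    pass.set_check_flag(CheckFlag::NOCHECK)
        .add_rule(typeid(ConvNode), [](Node* n, const std::vector<Node*>&) {
            n->layout = "NCHW4";  // pretend conversion to a 4-channel-packed layout
            return n;
        });

    pass.apply(graph);
    return conv.layout + "," + pool.layout;
}
```

In this sketch only node types with a registered rule are rewritten and all others pass through unchanged, which is one plausible reason for keying the table by type. Calling reformat_demo() exercises the chained setter and the dispatch in one place.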