Struct GraphLoadConfig

Struct Documentation

struct mgb::serialization::GraphLoadConfig

config for loading a whole graph; it is set up in GraphLoader

Public Types

using CompNodeMapper = thin_function<void(CompNode::Locator&)>
using TensorValueLoader = thin_function<void(void *ptr, const TensorLayout &layout, InputFile &fin)>

load tensor value into given memory address

  • ptr: destination pointer, or nullptr; if it is nullptr, fin must still be advanced (by calling InputFile::skip()) past the stored value of this tensor

  • layout: tensor layout, guaranteed to be contiguous
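A minimal sketch of the ptr/nullptr contract described above. It is written against hypothetical stand-ins (MockInputFile, MockLayout, MockTensorValueLoader) rather than the real MegEngine headers; MockLayout::nr_bytes is an assumed substitute for the byte size of a contiguous TensorLayout, and only the read/skip behavior is modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <functional>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for InputFile: a byte buffer with a read cursor.
struct MockInputFile {
    std::vector<unsigned char> data;
    std::size_t pos = 0;

    void read(void* dst, std::size_t size) {
        if (pos + size > data.size())
            throw std::runtime_error("read past end of input");
        std::memcpy(dst, data.data() + pos, size);
        pos += size;
    }
    void skip(std::size_t size) {
        if (pos + size > data.size())
            throw std::runtime_error("skip past end of input");
        pos += size;
    }
};

// Hypothetical stand-in for TensorLayout: since the layout is guaranteed
// to be contiguous, its byte size is all this sketch needs.
struct MockLayout {
    std::size_t nr_bytes;
};

using MockTensorValueLoader =
        std::function<void(void* ptr, const MockLayout& layout, MockInputFile& fin)>;

// Copies the raw stored bytes into ptr. When ptr is null the value is not
// wanted, but fin must still move past it so the next tensor lines up.
void raw_value_loader(void* ptr, const MockLayout& layout, MockInputFile& fin) {
    if (ptr == nullptr) {
        fin.skip(layout.nr_bytes);
    } else {
        fin.read(ptr, layout.nr_bytes);
    }
}
```

Forgetting the skip in the nullptr branch would leave the cursor at the start of the unwanted value, so every later tensor would be read from the wrong offset.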

using TensorModifier = thin_function<void(const std::string &name, bool has_value, HostTensorND &tensor)>

callback to modify loaded tensors

  • name: tensor name; it is empty for unnamed tensors

  • has_value: whether the tensor value has been dumped (params usually have values)

  • tensor: the tensor, which can be modified in place
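A sketch of a TensorModifier-style callback that overrides the batch dimension of one named input. MockHostTensor and make_batch_overrider are hypothetical names for illustration only; MockHostTensor models nothing but a shape, and the real HostTensorND API for changing a shape is not shown here.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for HostTensorND: only the shape is modelled.
struct MockHostTensor {
    std::vector<std::size_t> shape;
};

using MockTensorModifier = std::function<void(
        const std::string& name, bool has_value, MockHostTensor& tensor)>;

// Returns a modifier that sets the first dimension of the tensor called
// input_name. Tensors with dumped values (usually params) and unnamed
// tensors (empty name) are left untouched.
MockTensorModifier make_batch_overrider(std::string input_name, std::size_t batch) {
    return [input_name, batch](const std::string& name, bool has_value,
                               MockHostTensor& tensor) {
        if (has_value || name.empty() || name != input_name)
            return;
        if (!tensor.shape.empty())
            tensor.shape[0] = batch;
    };
}
```

Filtering on has_value keeps the callback from touching parameters, which normally carry dumped values and should keep their stored shapes.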

using OprLoaderMaker = thin_function<OprLoader(const std::string&)>

Public Functions

GraphLoadConfig(const CompNodeMapper &comp_node_mapper_ = {}, const OprLoaderMaker &opr_loader_maker_ = {}, const std::shared_ptr<UserDataContainer> &user_data_ = {}, const std::shared_ptr<ComputingGraph> &comp_graph_ = {}, const TensorValueLoader tensor_value_loader_ = {})

Public Members

bool const_var_shape = false

whether to make the shapes of all SharedDeviceTensor and Host2DeviceCopy oprs immutable, so that static inference can be performed eagerly; this can reduce memory usage. Because these shapes cannot change after loading, use tensor_modifier if a shape needs to be changed

TensorModifier tensor_modifier

callback to modify loaded tensors before they are inserted into the graph

CompNodeMapper comp_node_mapper

callback to modify the comp node locator in place
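A sketch of a CompNodeMapper-style callback that edits each locator in place, for example to load a model recorded on a GPU onto CPU. MockLocator and MockDeviceType are hypothetical stand-ins; the fields of the real CompNode::Locator are not reproduced here and may differ.

```cpp
#include <cassert>
#include <functional>

// Hypothetical stand-ins for the device type and CompNode::Locator.
enum class MockDeviceType { CPU, CUDA };

struct MockLocator {
    MockDeviceType type;
    int device;
};

using MockCompNodeMapper = std::function<void(MockLocator&)>;

// Rewrites CUDA locators to CPU in place; the device index is kept as-is.
void map_cuda_to_cpu(MockLocator& loc) {
    if (loc.type == MockDeviceType::CUDA)
        loc.type = MockDeviceType::CPU;
}
```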

OprLoaderMaker opr_loader_maker

maps an arbitrary identifier to an opr loader; see OprRegistry::add_using_dynamic_loader

std::shared_ptr<UserDataContainer> user_data

extra user data passed by the load caller to opr load implementations; useful for implementing nested opr loading

std::shared_ptr<ComputingGraph> comp_graph

computing graph to which new oprs are added; a new graph is created if it is null

TensorValueLoader tensor_value_loader

tensor value loader; it must match the tensor_value_dumper used in GraphDumpConfig
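To illustrate why the loader must match the dumper, here is a sketch of a hypothetical length-prefixed format: the dumper writes an 8-byte size header before the raw bytes, so only a loader that parses the same header can read the file back. MockOutputFile, MockInputFile, and both functions are hypothetical, and they take a plain byte count instead of a TensorLayout because the real tensor_value_dumper signature is not shown in this section.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for the output and input file types.
struct MockOutputFile {
    std::vector<unsigned char> data;
    void write(const void* src, std::size_t size) {
        const auto* p = static_cast<const unsigned char*>(src);
        data.insert(data.end(), p, p + size);
    }
};

struct MockInputFile {
    std::vector<unsigned char> data;
    std::size_t pos = 0;
    void read(void* dst, std::size_t size) {
        if (pos + size > data.size())
            throw std::runtime_error("read past end of input");
        std::memcpy(dst, data.data() + pos, size);
        pos += size;
    }
    void skip(std::size_t size) {
        if (pos + size > data.size())
            throw std::runtime_error("skip past end of input");
        pos += size;
    }
};

// Dumper side: 8-byte length header, then the raw bytes.
void length_prefixed_dumper(const void* ptr, std::size_t nr_bytes, MockOutputFile& fout) {
    std::uint64_t len = nr_bytes;
    fout.write(&len, sizeof(len));
    fout.write(ptr, nr_bytes);
}

// Matching loader: parse the header, check it, then copy or skip the payload.
void length_prefixed_loader(void* ptr, std::size_t nr_bytes, MockInputFile& fin) {
    std::uint64_t len = 0;
    fin.read(&len, sizeof(len));
    if (len != nr_bytes)
        throw std::runtime_error("stored size does not match layout");
    if (ptr == nullptr)
        fin.skip(nr_bytes);
    else
        fin.read(ptr, nr_bytes);
}
```

A raw reader that ignored the header would copy the 8 header bytes into the tensor and then be misaligned for every following value.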

Public Static Functions

void default_tensor_value_loader(void *ptr, const TensorLayout &layout, InputFile &fin)

a fallback for implementing a custom tensor value reader; it simply reads the raw tensor value from the input file. Implemented in serializer.cpp