# TensorsDataClass
A PyTorch extension library for organizing tensors as a structured tree of dataclasses, with built-in support for advanced collating mechanisms. The batch creation process seamlessly handles issues such as sequence padding, flattening a variable number of objects per example into a single batch dimension (and unflattening them back), converting within-example indices into batch-level indices, automatic creation of sequence and collate masks, and more.
## What pains TensorsDataClass aims to solve
A model's input frequently contains a variable number of sequences per example, where the sequence lengths themselves also vary. With many such inputs, things usually get messy: they are hard to handle, hard to name, hard to move to the GPU, and hard to abstract in an (X, y) fashion.
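For a sense of the boilerplate this replaces, here is a rough hand-rolled sketch in plain PyTorch (not part of this package; the field names `token_seq` and `object_feats` are made up for illustration) of what collating such inputs typically involves: padding variable-length sequences, building masks, and flattening a variable number of per-example objects into one dimension.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def manual_collate(examples):
    """Hand-rolled collation of examples, each holding a variable-length
    token sequence and a variable number of 'objects'."""
    # Pad variable-length sequences to a common length and build a mask.
    token_seqs = [ex["token_seq"] for ex in examples]            # each: (seq_len_i,)
    tokens = pad_sequence(token_seqs, batch_first=True)          # (batch, max_seq_len)
    tokens_mask = pad_sequence(
        [torch.ones_like(s, dtype=torch.bool) for s in token_seqs],
        batch_first=True)

    # Flatten a variable number of objects per example into one batch dimension,
    # remembering which example each object came from.
    object_feats = torch.cat([ex["object_feats"] for ex in examples], dim=0)
    object_to_example = torch.cat([
        torch.full((ex["object_feats"].shape[0],), i, dtype=torch.long)
        for i, ex in enumerate(examples)])

    return {"tokens": tokens, "tokens_mask": tokens_mask,
            "object_feats": object_feats, "object_to_example": object_to_example}
```

TensorsDataClass aims to take this kind of bookkeeping out of model and data-loading code and perform it automatically, including through nested structures.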
## Installation

```bash
pip install tensors-data-class
```

## Usage example
```python
import dataclasses
from typing import Dict, Optional

from tensors_data_class import *


# Each field's type tells the collate mechanism how that field should be
# batched (padded, flattened, index-fixed, ...).

@dataclasses.dataclass
class CodeExpressionTokensSequenceInputTensors(TensorsDataClass):
    token_type: BatchFlattenedSeq
    kos_token_index: BatchFlattenedTensor
    identifier_index: BatchedFlattenedIndicesFlattenedTensor

@dataclasses.dataclass
class SymbolsInputTensors(TensorsDataClass):
    symbols_identifier_indices: BatchedFlattenedIndicesFlattenedTensor
    symbols_appearances_symbol_idx: BatchedFlattenedIndicesFlattenedTensor
    symbols_appearances_expression_token_idx: Optional[BatchFlattenedTensor] = None
    symbols_appearances_cfg_expression_idx: Optional[BatchedFlattenedIndicesFlattenedTensor] = None

@dataclasses.dataclass
class CFGPathsInputTensors(TensorsDataClass):
    nodes_indices: BatchedFlattenedIndicesFlattenedSeq
    edges_types: BatchFlattenedSeq

@dataclasses.dataclass
class CFGPathsNGramsInputTensors(TensorsDataClass):
    nodes_indices: BatchedFlattenedIndicesFlattenedSeq
    edges_types: BatchFlattenedSeq

@dataclasses.dataclass
class PDGInputTensors(TensorsDataClass):
    cfg_nodes_control_kind: Optional[BatchFlattenedTensor] = None
    cfg_nodes_has_expression_mask: Optional[BatchFlattenedTensor] = None
    cfg_nodes_tokenized_expressions: Optional[CodeExpressionTokensSequenceInputTensors] = None
    cfg_nodes_random_permutation: Optional[BatchedFlattenedIndicesPseudoRandomPermutation] = None
    cfg_control_flow_paths: Optional[CFGPathsInputTensors] = None
    cfg_control_flow_paths_ngrams: Optional[Dict[int, CFGPathsNGramsInputTensors]] = None

@dataclasses.dataclass
class IdentifiersInputTensors(TensorsDataClass):
    sub_parts_batch: BatchFlattenedTensor
    identifier_sub_parts_index: BatchedFlattenedIndicesFlattenedSeq
    identifier_sub_parts_vocab_word_index: BatchFlattenedSeq
    identifier_sub_parts_hashings: BatchFlattenedSeq
    sub_parts_obfuscation: BatchFlattenedPseudoRandomSamplerFromRange

# The top-level structure of a single example; nested TensorsDataClass fields
# form the structured tree that the collate mechanism handles.
@dataclasses.dataclass
class MethodCodeInputTensors(TensorsDataClass):
    example_hash: str
    identifiers: IdentifiersInputTensors
    symbols: SymbolsInputTensors
    method_tokenized_code: Optional[CodeExpressionTokensSequenceInputTensors] = None
    pdg: Optional[PDGInputTensors] = None

# Build one instance per example (field values omitted here), then collate
# them into a single batched instance with the same structure.
example1 = MethodCodeInputTensors(...)
example2 = MethodCodeInputTensors(...)
batch = MethodCodeInputTensors.collate([example1, example2])
print(batch)
```
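In a training pipeline, the natural hook is a DataLoader's `collate_fn`. A minimal sketch, assuming your dataset items are `MethodCodeInputTensors` instances (the `MethodCodeDataset` wrapper below is hypothetical, not part of the library):

```python
from torch.utils.data import DataLoader, Dataset

class MethodCodeDataset(Dataset):
    """Hypothetical thin wrapper that yields pre-built MethodCodeInputTensors."""
    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

data_loader = DataLoader(
    MethodCodeDataset([example1, example2]),
    batch_size=2,
    # `collate` takes a list of examples (as shown above), which is exactly
    # what DataLoader hands to its collate_fn.
    collate_fn=MethodCodeInputTensors.collate,
)

for batch in data_loader:
    ...  # pass the batched, nested structure to the model
```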
## Different types for different use-cases

- `TensorsDataClass`
- `BatchFlattenedTensor`
- `BatchFlattenedSeq`
- `BatchedFlattenedIndicesFlattenedTensor`
- `BatchedFlattenedIndicesFlattenedSeq`
- `BatchedFlattenedIndicesPseudoRandomPermutation`
- `BatchFlattenedPseudoRandomSamplerFromRange`
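One reading of the `BatchFlattened*` vs. `BatchedFlattenedIndices*` naming, based on the description above: plain flattened fields are simply concatenated across examples, while fields holding within-example indices must additionally be offset so they keep pointing at the right rows of the flattened batch. A plain-PyTorch sketch of that distinction (an illustration of the concept, not the library's internals; `objects_per_example` and the index tensors are made up):

```python
import torch

# Number of "objects" each example contributes to the flattened batch.
objects_per_example = [3, 5, 2]

# A BatchFlattenedTensor-like field: per-example values are just concatenated.
values = [torch.randn(n, 4) for n in objects_per_example]
flattened_values = torch.cat(values, dim=0)             # shape: (3 + 5 + 2, 4)

# A BatchedFlattenedIndicesFlattenedTensor-like field: each example holds indices
# into ITS OWN objects, so they must be shifted by the preceding examples'
# object counts to remain valid after flattening.
within_example_indices = [torch.tensor([0, 2]), torch.tensor([1, 4, 4]), torch.tensor([0])]
offsets = [0, 3, 8]  # running totals of objects_per_example
batch_level_indices = torch.cat(
    [idx + off for idx, off in zip(within_example_indices, offsets)])
# -> tensor([0, 2, 4, 7, 7, 8]): now indexes rows of `flattened_values`.
```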