Source code for modalities.dataloader.dataset

from __future__ import annotations

from enum import Enum
from pathlib import Path
from typing import Optional

import jq
import numpy as np
from pydantic import BaseModel
from torch.utils.data.dataset import Dataset as TorchdataSet
from tqdm import tqdm
from transformers import BatchEncoding

from modalities.dataloader.create_packed_data import EmbeddedStreamData
from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper


class Dataset(TorchdataSet):
    """Common base class for the datasets defined in this module.

    It only records where the raw data lives and under which key a sample is
    stored; the subclasses in this module provide ``__len__`` / ``__getitem__``.
    """

    def __init__(self, raw_data_path: Path, sample_key: str):
        """
        Stores the raw data location and the sample key on the instance.

        Args:
            raw_data_path (Path): The path to the raw data.
            sample_key (str): The key used to access a sample in the dataset.
        """
        self.raw_data_path, self.sample_key = raw_data_path, sample_key
class DummySampleDataType(str, Enum):
    """
    Enumeration of the element data types a dummy sample entry can be generated with.

    Because the enum also subclasses ``str``, members compare equal to their plain
    string values (e.g. ``DummySampleDataType.FLOAT == "float"``), which lets them be
    given as strings in configuration.

    Attributes:
        FLOAT (str): Represents the float data type.
        INT (str): Represents the int data type.
    """

    FLOAT = "float"
    INT = "int"
class DummySampleConfig(BaseModel):
    """
    Configuration of a single entry of a dummy sample.

    Attributes:
        sample_key (str): The key under which the generated entry is stored in a sample.
        sample_shape (tuple[int, ...]): The shape of the generated array.
        sample_type (DummySampleDataType): The element data type of the generated array.
    """

    sample_key: str
    sample_shape: tuple[int, ...]
    sample_type: DummySampleDataType
class DummyDatasetConfig(BaseModel):
    """
    Configuration for a dummy dataset of randomly generated samples.

    Attributes:
        num_samples (int): The number of samples in the dataset.
        sample_definition (list[DummySampleConfig]): One definition per entry of each
            sample (key, shape and data type).
    """

    num_samples: int
    sample_definition: list[DummySampleConfig]
class DummyDataset(Dataset):
    """Dataset that returns freshly generated random samples instead of reading stored data."""

    def __init__(
        self,
        num_samples: int,
        sample_definition: tuple[DummySampleConfig, ...] | list[DummySampleConfig],
    ):
        """
        Initializes a DummyDataset object with the given number of samples and sample definition.

        When calling the __getitem__ method, the dataset returns a random sample based on the
        sample definition.

        Args:
            num_samples (int): The number of samples in the dataset.
            sample_definition (tuple[DummySampleConfig, ...] | list[DummySampleConfig]):
                The definitions of the dataset output. Each entry contains the sample key,
                shape and data type.
        """
        # A dummy dataset has no backing file and no single sample key.
        super().__init__(raw_data_path=None, sample_key=None)
        self.num_samples = num_samples
        self.sample_definition = sample_definition

    def __len__(self) -> int:
        """
        Returns the length of the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return self.num_samples

    def __getitem__(self, idx: int) -> dict:
        """
        Returns a freshly generated random sample.

        Args:
            idx (int): The index of the item. It must lie in ``[-len(self), len(self))``,
                but it does not otherwise influence the returned sample.

        Returns:
            dict: A mapping from each configured sample key to a random numpy array.

        Raises:
            IndexError: If ``idx`` is out of bounds. Raising here also lets plain iteration
                (``for sample in dataset``) stop after ``num_samples`` items instead of
                looping forever.
        """
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError("Index out of bounds")
        return self._create_random_sample()

    def _create_random_sample(self) -> dict:
        # Creates one random array per entry of the sample definition.
        sample = {}
        for s in self.sample_definition:
            if s.sample_type == DummySampleDataType.FLOAT:
                # standard-normal floats of the requested shape
                data = np.random.randn(*s.sample_shape)
            elif s.sample_type == DummySampleDataType.INT:
                # integers drawn uniformly from [0, 512)
                data = np.random.randint(low=0, high=512, size=s.sample_shape)
            else:
                raise NotImplementedError(f"DummyDataset does not support type {s.sample_type}")
            sample[s.sample_key] = data
        return sample
class MemMapDataset(Dataset):
    """Dataset over a JSONL file that extracts text with a jq filter and tokenizes it on access."""

    def __init__(
        self,
        raw_data_path: Path,
        tokenizer: TokenizerWrapper,
        sample_key: str,
        index_path: Optional[Path] = None,
        jq_pattern: str = ".text",
    ):
        """
        Initializes the MemMapDataset object that represents a PyTorch Dataset with mmap support.

        Args:
            raw_data_path (Path): Path to a JSONL file, which holds text data.
            tokenizer (TokenizerWrapper): The tokenizer object that is required to tokenize text data.
            sample_key (str): The key to access the sample in the BatchEncoding.
            index_path (Optional[Path], optional): The path to the index file which indicates
                the start character/byte position of documents. Defaults to None.
            jq_pattern (str, optional): The jq pattern to filter the data. Results are afterwards
                tokenized and packed. Defaults to ".text".
        """
        super().__init__(raw_data_path=raw_data_path, sample_key=sample_key)
        self.reader = LargeFileLinesReader(self.raw_data_path, index_path=index_path)
        # Compile once; the compiled program is reused for every item.
        self.jq_filter = jq.compile(jq_pattern)
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        """
        Returns the length of the dataset.

        Returns:
            int: The number of lines (documents) available through the reader.
        """
        return len(self.reader)

    def __getitem__(self, idx: int) -> BatchEncoding:
        """
        Retrieves and tokenizes the item at the given index.

        Args:
            idx (int): The index of the item to retrieve. Negative indices count from the end,
                as for Python sequences.

        Returns:
            BatchEncoding: The tokenized representation of the item.

        Raises:
            IndexError: If the index is out of bounds.
        """
        num_documents = len(self.reader)
        if idx < 0:
            # Normalise here rather than relying on the reader to support negative indices.
            idx += num_documents
        if not 0 <= idx < num_documents:
            raise IndexError("Index out of bounds")
        raw_line = self.reader[idx]  # presumably one JSON document per line (JSONL input)
        # `.first()` takes the first output produced by the jq program for this line.
        text = self.jq_filter.input_text(raw_line).first()
        return self.tokenizer.tokenize(text=text)
class PackedMemMapDatasetBase(Dataset):
    """Dataset over a packed binary file (*.pbin) returning token arrays per index entry.

    The packing index built by ``_generate_packing_index`` is a sequence of
    ``(offset_in_bytes, length_in_bytes)`` pairs into the memory-mapped data of the
    file; subclasses override ``_generate_packing_index`` to change how samples are cut.
    """

    DATA_SECTION_LENGTH_IN_BYTES = EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES
    TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES
    HEADER_SIZE_IN_BYTES = EmbeddedStreamData.HEADER_SIZE_IN_BYTES
    # On-disk token dtypes keyed by token size in bytes (little-endian unsigned ints).
    np_dtype_of_tokens_on_disk_from_bytes = {
        1: np.dtype(np.uint8).newbyteorder("<"),
        2: np.dtype(np.uint16).newbyteorder("<"),
        4: np.dtype(np.uint32).newbyteorder("<"),
    }
    # In-RAM dtypes keyed by token size in bytes; wider signed ints because torch
    # cannot convert most unsigned formats.
    type_converter_for_torch = {1: np.uint8, 2: np.int32, 4: np.int64}

    def __init__(self, raw_data_path: Path, sample_key: str, load_index: Optional[bool] = True):
        """
        Initializes the PackedMemMapDatasetBase object.

        Args:
            raw_data_path (Path): Path to a packed binary file (*.pbin).
                Use `modalities data pack_encoded_data` to create one based on a JSONL-file.
            sample_key (str): The key to access the sample in the BatchEncoding.
            load_index (bool, optional): Flag indicating whether to load the index. Defaults to True.

        Raises:
            RuntimeError: If the token representation with the given size is not supported.

        Note:
            TODO: sample_key should support multi-modal features using separately encoded inputs,
                  this needs to get replaced with a list of sample keys!
        """
        super().__init__(raw_data_path=raw_data_path, sample_key=sample_key)
        self._embedded_stream_data = EmbeddedStreamData(raw_data_path, load_index=load_index)
        self._token_size_in_bytes = self._embedded_stream_data.token_size_in_bytes
        try:
            self._token_dtype_on_disk = self.np_dtype_of_tokens_on_disk_from_bytes[self._token_size_in_bytes]
            self._token_dtype_in_ram = self.type_converter_for_torch[self._token_size_in_bytes]
        except KeyError as e:
            raise RuntimeError(
                f"Encountered a required token representation with {self._token_size_in_bytes} bytes,"
                " which is not supported. Consider using a smaller vocabulary."
            ) from e
        # Built last: subclasses' implementations may depend on the attributes set above.
        self._index = self._generate_packing_index()

    @property
    def token_size_in_bytes(self) -> int:
        """Number of bytes used to store a single token on disk."""
        return self._token_size_in_bytes

    def _generate_packing_index(self) -> list[tuple[int, int]]:
        # Generates the packing index for the dataset: one (offset_in_bytes, length_in_bytes)
        # entry per item, taken as-is from the file's document index.
        return self._embedded_stream_data.index_base

    def __len__(self) -> int:
        """
        Returns the length of the dataset.

        Returns:
            int: The number of entries in the packing index.
        """
        return len(self._index)

    def __getitem__(self, idx: int | slice) -> BatchEncoding:
        """
        Retrieves the item at the given index or a contiguous slice of items.

        Args:
            idx (int | slice): The index of the item to retrieve, or a slice of items whose
                step is None or 1.

        Returns:
            BatchEncoding: For an integer index, ``{sample_key: tokens}`` with a single token
            array. For a slice, ``{sample_key: [tokens, ...]}`` with one array per item
            (an empty list for an empty slice).

        Raises:
            ValueError: If ``idx`` is a slice with a step other than None or 1.
        """
        if isinstance(idx, slice):
            if idx.step is not None and idx.step != 1:
                raise ValueError("Slicing with step != 1 is not supported.")
            item_positions = self._index[idx]
            if len(item_positions) == 0:
                return BatchEncoding(data={self.sample_key: []})
        else:
            # (offset_in_bytes, length_in_bytes)
            item_positions = [self._index[idx]]

        # Read one byte range from the first item's start to the last item's end
        # (assumes positions are ordered by offset). np.frombuffer uses the memmap as buffer:
        # `offset` is in bytes, `count` is in tokens of self._token_dtype_on_disk.
        num_bytes_start = item_positions[0][0]
        num_bytes_stop = item_positions[-1][0] + item_positions[-1][1]
        num_tokens = (num_bytes_stop - num_bytes_start) // self._token_size_in_bytes
        tokens = np.frombuffer(
            buffer=self._embedded_stream_data.data,
            dtype=self._token_dtype_on_disk,
            count=num_tokens,
            offset=num_bytes_start,
        )
        # torch can't convert most uint-formats, therefore we infer regular int types
        tokens = tokens.astype(self._token_dtype_in_ram)

        # Cut the shared token buffer back into the individual items.
        documents = []
        for offset_in_bytes, length_in_bytes in item_positions:
            token_start = (offset_in_bytes - num_bytes_start) // self._token_size_in_bytes
            token_end = (offset_in_bytes + length_in_bytes - num_bytes_start) // self._token_size_in_bytes
            documents.append(tokens[token_start:token_end])

        # TODO: the return type is inconsistent here.
        # If idx is an integer, we return a BatchEncoding with a single document.
        # If idx is a slice, we return a BatchEncoding with a list of documents.
        if isinstance(idx, slice):
            return BatchEncoding(data={self.sample_key: documents})
        return BatchEncoding(data={self.sample_key: documents[0]})
class PackedMemMapDatasetContinuous(PackedMemMapDatasetBase):
    """Packed dataset that cuts the whole token stream into fixed-size blocks.

    Document boundaries are ignored. Consecutive blocks overlap by exactly one token,
    so the last target token of one sample is the first input token of the next.
    """

    def __init__(self, raw_data_path: Path, sample_key: str, block_size: int, load_index: Optional[bool] = False):
        """
        Initializes the PackedMemMapDatasetContinuous object.

        Args:
            raw_data_path (Path): Path to a packed binary file (*.pbin).
                Use `modalities data pack_encoded_data` to create one based on a JSONL-file.
            sample_key (str): The key to access the sample in the BatchEncoding.
            block_size (int): The number of tokens per sample; must be at least 2.
            load_index (bool, optional): Flag indicating whether to load the index.
                This is only needed for debugging purposes to index the original documents.
                The continuous packing does not need to load the index and should be
                deactivated as it significantly increases the instantiation time. Defaults to False.
        """
        # Must be set before super().__init__, which calls _generate_packing_index.
        self.block_size = block_size
        # TODO passing the load_index flag does not really comply with the inversion
        # of control principle. We should refactor this in the future.
        super().__init__(raw_data_path=raw_data_path, sample_key=sample_key, load_index=load_index)

    @staticmethod
    def _create_packed_index(total_tokens: int, block_size: int, token_size_in_bytes: int) -> np.ndarray:
        # Given a fixed number of samples we can compute the total number of tokens as
        #   num_tokens = block_size + (block_size-1) * (num_samples-1)
        # as the first sample always needs block_size many tokens and the following samples
        # each need block_size-1 many tokens (since we can reuse the last target token as the
        # first input token of the subsequent sample). Solving for num_samples:
        num_samples = (total_tokens - block_size) // (block_size - 1) + 1
        sample_ids = np.arange(num_samples)
        # Sample i starts at token i * (block_size - 1); convert to bytes.
        offsets_in_bytes = sample_ids * (block_size - 1) * token_size_in_bytes
        # Every sample spans exactly block_size tokens.
        lengths_in_bytes = np.full(num_samples, block_size * token_size_in_bytes)
        # Shape (num_samples, 2): rows of (offset_in_bytes, length_in_bytes).
        return np.stack((offsets_in_bytes, lengths_in_bytes), axis=1)

    def _generate_packing_index(self) -> np.ndarray:
        # Generates the packing index for the dataset as an array of
        # (offset_in_bytes, length_in_bytes) rows.
        # Validate first: block_size - 1 is used as a divisor in _create_packed_index.
        if self.block_size < 2:
            raise ValueError("Block size must be at least 2.")
        total_tokens = self._embedded_stream_data.data_len // self._token_size_in_bytes
        if total_tokens < self.block_size:
            raise ValueError(
                f"Block size ({self.block_size}) is larger than the "
                f"total number of tokens in the dataset ({total_tokens})."
            )
        return self._create_packed_index(total_tokens, self.block_size, self._token_size_in_bytes)
class PackedMemMapDatasetMegatron(PackedMemMapDatasetBase):
    """Packed dataset that groups consecutive documents into samples of ``block_size`` tokens.

    The index entries are ``(offset_in_bytes, length_in_bytes)`` pairs, which is the
    format ``PackedMemMapDatasetBase.__getitem__`` reads.
    """

    def __init__(self, raw_data_path: Path, sample_key: str, block_size: int):
        # Must be set before super().__init__, which calls _generate_packing_index.
        self.block_size = block_size
        super().__init__(raw_data_path=raw_data_path, sample_key=sample_key)

    def _generate_packing_index(self) -> list[tuple[int, int]]:
        # Builds a list of (offset_in_bytes, length_in_bytes) samples of block_size tokens each.
        index = []
        # NOTE(review): starting at HEADER_SIZE_IN_BYTES assumes index_base offsets are measured
        # from the start of the file rather than from the data section — confirm against
        # EmbeddedStreamData.
        curr_offset = self.HEADER_SIZE_IN_BYTES
        curr_len = 0
        block_size_in_bytes = self.block_size * self._token_size_in_bytes
        for segment_offset, segment_len in tqdm(self._embedded_stream_data.index_base):
            # When the sum of the length of the current previously seen samples doesn't
            # exceed block_size_in_bytes, we add the current segment length to the previous
            # ones and continue.
            if curr_len + segment_len < block_size_in_bytes:
                curr_len += segment_len
            # If the previous and current length equals block_size_in_bytes, we add the starting index
            # and the total sequence length (in bytes) to the index list as a new sample.
            elif curr_len + segment_len == block_size_in_bytes:
                index.append((curr_offset, block_size_in_bytes))
                curr_len = 0
                curr_offset += block_size_in_bytes
            # Else case is executed when the current and previous segment length exceed the block_size.
            # In this case we set the starting point of the next sample to the end of the current sample.
            # This way, the start of a sample is never in the middle of a sentence.
            else:
                index.append((curr_offset, block_size_in_bytes))
                if segment_len > block_size_in_bytes:
                    curr_offset += block_size_in_bytes
                    curr_len = 0
                else:
                    curr_offset = segment_offset
                    curr_len = segment_len
        return index
class CombinedDataset(Dataset):
    """Combines multiple datasets into one large dataset at runtime.

    Note:
        When using this class to combine multiple `PackedMemMapDataset`s, each packed sample
        is packed from a single dataset (i.e., the samples are not mixed between datasets).
        In the Dataloader, a batch will still contain packed samples from different datasets.
    """

    def __init__(self, datasets: list[Dataset]):
        """Initializes the CombinedDataset object, combining multiple datasets.

        Args:
            datasets (list[Dataset]): A list of datasets to combine. Items are addressed in
                list order: first all items of datasets[0], then datasets[1], and so on.
        """
        # The combined view has no single backing file or sample key of its own.
        super().__init__(raw_data_path=None, sample_key=None)
        self.datasets = datasets
        # cumulative_sizes[i] is the total number of items in datasets[0..i].
        self.cumulative_sizes = np.cumsum([len(ds) for ds in datasets], dtype=np.int64)

    def __len__(self) -> int:
        """Returns the total number of items across all datasets (0 if there are none)."""
        if len(self.cumulative_sizes) == 0:
            return 0
        return int(self.cumulative_sizes[-1])

    def __getitem__(self, idx: int) -> dict:
        """Returns the item at global index ``idx`` from the dataset that contains it.

        Args:
            idx (int): Global index; negative values count from the end.

        Returns:
            The item returned by the underlying dataset.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        total = len(self)
        global_idx = idx + total if idx < 0 else idx
        if not 0 <= global_idx < total:
            raise IndexError(f"Index {idx} out of range for CombinedDataset of length {total}")
        # side="right" skips over empty datasets whose cumulative size equals global_idx.
        dataset_idx = int(np.searchsorted(self.cumulative_sizes, global_idx, side="right"))
        preceding_items = int(self.cumulative_sizes[dataset_idx - 1]) if dataset_idx > 0 else 0
        return self.datasets[dataset_idx][global_idx - preceding_items]