diff --git a/litgpt/data/flan.py b/litgpt/data/flan.py index 975894aaa3..509226d372 100644 --- a/litgpt/data/flan.py +++ b/litgpt/data/flan.py @@ -44,7 +44,7 @@ class FLAN(DataModule): train_dataset: SFTDataset | None = field(default=None, init=False, repr=False) test_dataset: SFTDataset | None = field(default=None, init=False, repr=False) - def __post_init__(self): + def __post_init__(self) -> None: super().__init__() if isinstance(self.prompt_style, str): self.prompt_style = PromptStyle.from_name(self.prompt_style) @@ -73,10 +73,10 @@ def prepare_data(self) -> None: data_file_url = f"{self.url}/{split}/{subset}_{split}.jsonl" download_if_missing(data_file_path, data_file_url) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: return self._dataloader("train") - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: return self._dataloader("test") def _dataloader(self, split: str) -> DataLoader: