From 023c06d476b73083443dcfd5db241fa435783452 Mon Sep 17 00:00:00 2001 From: Johnny Slos Date: Sun, 23 Mar 2025 09:43:50 +0100 Subject: [PATCH] feat: now supports already hot encoded labels --- balanced_loss/losses.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/balanced_loss/losses.py b/balanced_loss/losses.py index eceab97..685ba54 100644 --- a/balanced_loss/losses.py +++ b/balanced_loss/losses.py @@ -92,7 +92,13 @@ def forward(self, logits: torch.tensor, labels: torch.tensor): batch_size = logits.size(0) num_classes = logits.size(1) - labels_one_hot = F.one_hot(labels, num_classes).float() + + if labels.ndim == 2 and labels.shape[1] == num_classes: + # Assuming labels are already one-hot + labels_one_hot = labels.float() # Ensure it's float, if needed + else: + # Perform one-hot encoding + labels_one_hot = F.one_hot(labels, num_classes).float() if self.class_balanced: effective_num = 1.0 - np.power(self.beta, self.samples_per_class)