class LossContainer:
    """Wrap a loss callable together with a scalar weight.

    Calling the container forwards all positional and keyword arguments to
    the wrapped ``loss`` and multiplies the result by ``weight``. When the
    weight is not strictly positive, the wrapped loss is not evaluated at
    all and ``0.0`` is returned instead.
    """

    def __init__(self, weight, loss):
        """Store the weight and the loss callable.

        Args:
            weight: Scalar multiplier applied to the loss result. It is
                compared with ``> 0`` to decide whether the loss is computed.
            loss: Callable invoked as ``loss(*args, **kwargs)`` when the
                container is called with a positive weight.
        """
        self.weight = weight
        self.loss = loss

    def __call__(self, *args, **kwargs):
        """Return ``weight * loss(*args, **kwargs)``, or ``0.0`` if disabled.

        Args:
            *args: Positional arguments forwarded unchanged to ``loss``.
            **kwargs: Keyword arguments forwarded unchanged to ``loss``.

        Returns:
            The weighted loss when ``weight > 0``. Otherwise the float
            ``0.0``; in that case ``loss`` is never called.
        """
        # Skip evaluating the (possibly expensive) loss when it would be
        # multiplied away or is deliberately disabled with a non-positive weight.
        if self.weight > 0:
            return self.weight * self.loss(*args, **kwargs)
        return 0.0

    def __repr__(self):
        return f"{type(self).__name__}(weight={self.weight!r}, loss={self.loss!r})"