feat: initial for IDF

This commit is contained in:
Robin Meersman 2025-11-07 12:54:36 +01:00
commit ef4684ef39
27 changed files with 2830 additions and 0 deletions

View file

@ -0,0 +1,36 @@
import torch
class Base(torch.nn.Module):
"""
The base class for modules. That contains a disable round mode
"""
def __init__(self):
super().__init__()
def _set_child_attribute(self, attr, value):
r"""Sets the module in rounding mode.
This has any effect only on certain modules if variable type is
discrete.
Returns:
Module: self
"""
if hasattr(self, attr):
setattr(self, attr, value)
for module in self.modules():
if hasattr(module, attr):
setattr(module, attr, value)
return self
def set_temperature(self, value):
self._set_child_attribute("temperature", value)
def enable_hard_round(self, mode=True):
self._set_child_attribute("hard_round", mode)
def disable_hard_round(self, mode=True):
self.enable_hard_round(not mode)