This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../integer_discrete_flows/models/utils.py
2025-11-07 12:54:36 +01:00

36 lines
915 B
Python

import torch
class Base(torch.nn.Module):
    """Base class for modules with attributes that are toggled tree-wide.

    Offers helpers that broadcast a value for a named attribute
    (``temperature``, ``hard_round``) to this module and to every submodule
    that already defines an attribute of that name.
    """

    def __init__(self):
        super().__init__()

    def _set_child_attribute(self, attr, value):
        r"""Set ``attr`` to ``value`` on every module that already has it.

        Walks ``self.modules()`` and overwrites ``attr`` wherever it is
        already present. Modules that do not define ``attr`` are left
        untouched, so the call only has an effect on modules that opted in
        by declaring the attribute.

        Args:
            attr (str): Name of the attribute to overwrite.
            value: New value to assign.

        Returns:
            Module: self
        """
        # ``modules()`` yields ``self`` first and then all descendants, so a
        # separate check on ``self`` would only assign the value twice.
        for module in self.modules():
            if hasattr(module, attr):
                setattr(module, attr, value)
        return self

    def set_temperature(self, value):
        """Set ``temperature`` on every module in the tree that defines it.

        Args:
            value: New temperature value.

        Returns:
            Module: self
        """
        return self._set_child_attribute("temperature", value)

    def enable_hard_round(self, mode=True):
        """Set ``hard_round`` to ``mode`` on every module that defines it.

        Args:
            mode (bool): Value assigned to ``hard_round``. Defaults to True.

        Returns:
            Module: self
        """
        return self._set_child_attribute("hard_round", mode)

    def disable_hard_round(self, mode=True):
        """Set ``hard_round`` to ``not mode`` on every module that defines it.

        Args:
            mode (bool): When True (the default) ``hard_round`` is switched
                off; when False it is switched on.

        Returns:
            Module: self
        """
        return self.enable_hard_round(not mode)