36 lines
915 B
Python
36 lines
915 B
Python
import torch
|
|
|
|
|
|
class Base(torch.nn.Module):
|
|
"""
|
|
The base class for modules. That contains a disable round mode
|
|
"""
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def _set_child_attribute(self, attr, value):
|
|
r"""Sets the module in rounding mode.
|
|
|
|
This has any effect only on certain modules if variable type is
|
|
discrete.
|
|
|
|
Returns:
|
|
Module: self
|
|
"""
|
|
if hasattr(self, attr):
|
|
setattr(self, attr, value)
|
|
|
|
for module in self.modules():
|
|
if hasattr(module, attr):
|
|
setattr(module, attr, value)
|
|
return self
|
|
|
|
def set_temperature(self, value):
|
|
self._set_child_attribute("temperature", value)
|
|
|
|
def enable_hard_round(self, mode=True):
|
|
self._set_child_attribute("hard_round", mode)
|
|
|
|
def disable_hard_round(self, mode=True):
|
|
self.enable_hard_round(not mode)
|