Monkey patching a PyTorch model
Well, ideally you should not do this.
But sometimes you just need to quickly test something and amend your model on the fly.
This helps:
import torch
import functools

def rsetattr(obj, attr, val):
    # setattr that accepts a dotted path like 'path.to.some.block'
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)

def rgetattr(obj, attr, *args):
    # getattr that accepts a dotted path
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))

for old_module_path, old_module_object in model.named_modules():
    # replace an old module with a new one,
    # copying over its settings and state
    if isinstance(old_module_object, torch.nn.SomeClass):
        new_module = SomeOtherClass(old_module_object.some_settings,
                                    old_module_object.some_other_settings)
        new_module.load_state_dict(old_module_object.state_dict())
        rsetattr(model, old_module_path, new_module)
The above code essentially does the same as:

model.path.to.some.block = some_other_block
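Here is a minimal runnable sketch of the same trick on a toy model. The model, and the choice of swapping ReLU for GELU, are illustrative placeholders, not from the original snippet; note that the replacement targets are collected first, since mutating the model while iterating over named_modules() can skip entries:

```python
import functools
import torch

def rgetattr(obj, attr, *args):
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))

def rsetattr(obj, attr, val):
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)

# hypothetical toy model with a nested block
model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Linear(4, 2)),
)

# collect dotted paths of modules to replace, then patch them in place
to_replace = [name for name, m in model.named_modules()
              if isinstance(m, torch.nn.ReLU)]
for name in to_replace:
    rsetattr(model, name, torch.nn.GELU())
```

After this, rgetattr(model, '1.0') returns the new GELU module that replaced the ReLU at model[1][0].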
#python
#pytorch
#deep_learning
#oop