Spark in me
Lost like tears in rain. DS, ML, a bit of philosophy and math. No bs or ads.
Monkey patching a PyTorch model

Well, ideally you should not do this.
But sometimes you just need to quickly test something and amend your model on the fly.

This helps:


import functools

import torch

def rsetattr(obj, attr, val):
    # set a nested attribute by its dotted path, e.g. "layer1.0.conv1"
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)

def rgetattr(obj, attr, *args):
    # get a nested attribute by its dotted path; an optional default
    # is passed through to every getattr call
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))

# snapshot with list() so the module tree is not walked while it is mutated
for old_module_path, old_module_object in list(model.named_modules()):
    # replace an old object with the new one,
    # copying some of its settings and its state
    # (torch.nn.SomeClass, SomeOtherClass and the *_settings are placeholders)
    if isinstance(old_module_object, torch.nn.SomeClass):
        new_module = SomeOtherClass(old_module_object.some_settings,
                                    old_module_object.some_other_settings)

        new_module.load_state_dict(old_module_object.state_dict())
        rsetattr(model, old_module_path, new_module)
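The two helpers are the common recursive getattr/setattr recipe built on functools.reduce. A minimal torch-free check of how they resolve a dotted path might look like this; Node and the attribute names are hypothetical stand-ins for a nested module tree.

```python
import functools


def rgetattr(obj, attr, *args):
    # reduce over the dotted path: getattr(getattr(obj, "a"), "b"), ...
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))


def rsetattr(obj, attr, val):
    # "a.b.c" -> pre="a.b", post="c": fetch the parent, then set on it
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


class Node:
    # hypothetical stand-in for a nested module tree
    def __init__(self, **children):
        self.__dict__.update(children)


tree = Node(encoder=Node(block=Node(act="relu")))
print(rgetattr(tree, "encoder.block.act"))        # relu
rsetattr(tree, "encoder.block.act", "gelu")
print(tree.encoder.block.act)                     # gelu
print(rgetattr(tree, "encoder.missing", None))    # None
```

One caveat: named_modules() also yields the root module with an empty path, and for that path rsetattr ends up calling setattr(obj, "", val), which does not replace the root object itself. The swap is only meaningful for submodules.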


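The replacement loop essentially does the same as

model.path.to.some.block = some_other_block

for every submodule of the target class, with each dotted path looked up at runtime instead of typed out by hand.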
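An end-to-end sketch of the same idea, assuming a toy nn.Sequential model and a hypothetical LoggingLinear subclass as the replacement: every plain nn.Linear is swapped, its weights are copied over with load_state_dict, and the outputs should stay the same.

```python
import functools

import torch
import torch.nn as nn


def rgetattr(obj, attr, *args):
    # walk a dotted path such as "2.0" one attribute at a time
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))


def rsetattr(obj, attr, val):
    # resolve the parent of the last path component, then set on it
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


class LoggingLinear(nn.Linear):
    # hypothetical drop-in replacement: same parameters as nn.Linear,
    # but counts how many times forward() is called
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return super().forward(x)


def swap_linears(model):
    # snapshot the module list before mutating the tree
    for path, old in list(model.named_modules()):
        if type(old) is nn.Linear:
            new = LoggingLinear(old.in_features, old.out_features,
                                bias=old.bias is not None)
            # copy weight and bias from the original layer
            new.load_state_dict(old.state_dict())
            rsetattr(model, path, new)
    return model


torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
    nn.Linear(8, 2),
)
x = torch.randn(3, 4)
with torch.no_grad():
    before = model(x)
    swap_linears(model)
    after = model(x)
```

Matching with type(old) is nn.Linear rather than isinstance keeps the function from re-wrapping layers that are already LoggingLinear if it is called a second time. Since the weights are copied, before and after should match, and after the single forward pass each new layer's calls counter is 1.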

#python
#pytorch
#deep_learning
#oop