| | import difflib |
| |
|
| | import torch |
| |
|
| |
|
def get_layer(l_name, library=torch.nn):
    """Return the layer object handler named ``l_name`` from ``library``.

    The lookup is case insensitive, e.g. ``get_layer("elu")`` returns
    ``torch.nn.ELU``.

    Args:
        l_name (string): Case insensitive name of the layer in ``library``
            (e.g. 'elu').
        library (module): Library/module (or any object supporting ``dir`` and
            ``getattr``) in which to search for the handler, e.g. ``torch.nn``.

    Returns:
        layer_handler (object): Handler for the requested layer
            (e.g. ``torch.nn.ELU``).

    Raises:
        NotImplementedError: If no attribute of ``library`` matches ``l_name``
            (the message lists the closest names), or if several attributes
            match it case-insensitively (the message lists all of them).

    """
    # Search the library that was actually requested, not always torch.nn,
    # so the names we match are the names we later getattr().
    candidates = dir(library)
    wanted = l_name.lower()
    matches = [name for name in candidates if name.lower() == wanted]

    if not matches:
        # Fuzzy-match on lowercased names (consistent with the case-insensitive
        # lookup) but report the real attribute names so they can be reused.
        by_lower = {name.lower(): name for name in candidates}
        close_matches = [
            by_lower[key]
            for key in difflib.get_close_matches(wanted, list(by_lower))
        ]
        raise NotImplementedError(
            "Layer with name {} not found in {}.\n Closest matches: {}".format(
                l_name, str(library), close_matches
            )
        )

    if len(matches) > 1:
        # Ambiguous request: show every colliding attribute name.
        raise NotImplementedError(
            "Multiple matches for layer with name {} found in {}.\n "
            "All matches: {}".format(l_name, str(library), matches)
        )

    return getattr(library, matches[0])