
PyTorch: More details about torch.nn.Module & Imports

by JK from Korea 2023. 5. 15.


 

Date: 2023.05.10            

 

* The PyTorch series mainly covers the problems I faced. For the actual code, check out my GitHub repository.

 

[torch.nn.Module]

Here is an easy explanation of what torch.nn.Module is and why we use it. We subclass it, call its __init__ through super().__init__(), and override certain methods such as forward().

 

In PyTorch, the torch.nn.Module class is used as a base class for building neural network models. When building a neural network model in PyTorch, we need to subclass the torch.nn.Module class and define the __init__ and forward methods.

 

The torch.nn.Module class acts as an abstract base class. It provides many methods and attributes that are useful for building neural network models, but it leaves the forward pass for each subclass to define. In a subclass, the __init__ method defines the architecture of the model (its layers) and initializes their parameters. The forward method defines the forward pass of the model: it takes an input tensor and produces an output tensor.
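Here is a minimal sketch of such a subclass. The class name CircleModel and the layer sizes are illustrative assumptions. The sizes mirror the one-hidden-layer, 5-node network mentioned later in this post.

```python
import torch
from torch import nn

class CircleModel(nn.Module):
    """Hypothetical binary classifier: 2 input features -> 5 hidden units -> 1 logit."""

    def __init__(self, in_features: int = 2, hidden_units: int = 5):
        super().__init__()  # sets up the internal bookkeeping nn.Module relies on
        # Layers assigned as attributes are registered as sub-modules automatically.
        self.layer_1 = nn.Linear(in_features, hidden_units)
        self.layer_2 = nn.Linear(hidden_units, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The forward pass: input tensor in, output tensor (raw logits) out.
        return self.layer_2(self.layer_1(x))

model = CircleModel()
x = torch.randn(8, 2)   # a batch of 8 samples with 2 features each
logits = model(x)
print(logits.shape)     # torch.Size([8, 1])
```

Only __init__ and forward are written by hand. Everything else (parameters(), .to(), state_dict(), and so on) is inherited from nn.Module.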

 

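The torch.nn.Module class also provides several groups of methods:
- registering and accessing the model's sub-modules (e.g., add_module(), children(), named_modules())
- moving the model to a GPU with .to(device)
- exporting and restoring its parameters with state_dict() and load_state_dict()

The actual writing to and reading from disk is done with torch.save() and torch.load().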
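The sketch below exercises those utilities on a small nn.Sequential model. The in-memory io.BytesIO buffer is my stand-in for a real file path such as "model.pth", so that the example does not touch the disk.

```python
import io
import torch
from torch import nn

model = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 1))

# Registered sub-modules can be listed by name.
for name, child in model.named_children():
    print(name, child)

# .to() moves every registered parameter and buffer; use a GPU if one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# state_dict() maps parameter names to tensors.
print(list(model.state_dict().keys()))  # ['0.weight', '0.bias', '1.weight', '1.bias']

# torch.save / torch.load do the actual I/O (a BytesIO buffer stands in for a file).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Rebuild the same architecture, then load the saved parameters into it.
restored = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 1)).to(device)
restored.load_state_dict(torch.load(buffer, map_location=device, weights_only=True))
```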

 

Now let's take a look at the source code of torch.nn.Module and try to simplify what it is doing.

 

The torch.nn.Module class defines the basic structure and behavior of a neural network module in PyTorch. It provides a way to encapsulate a neural network module and to define its forward pass. It also provides a way to register and access sub-modules of the module.
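To make this concrete, here is a heavily simplified, pure-Python toy that imitates the bookkeeping idea. It is not PyTorch's actual source. MiniModule, MiniParameter, Scale, and TwoScales are hypothetical names. Like nn.Module, the toy routes attribute assignments into dictionaries, collects parameters recursively, and sends calls to forward().

```python
class MiniParameter:
    """Hypothetical stand-in for nn.Parameter: just wraps a value."""
    def __init__(self, value):
        self.value = value

class MiniModule:
    """Toy imitation of nn.Module's bookkeeping (not PyTorch's real code)."""

    def __init__(self):
        # object.__setattr__ bypasses our own __setattr__ while creating the dicts.
        object.__setattr__(self, "_modules", {})
        object.__setattr__(self, "_parameters", {})

    def __setattr__(self, name, value):
        # Route sub-modules and parameters into their dictionaries on assignment.
        if isinstance(value, MiniModule):
            self._modules[name] = value
        elif isinstance(value, MiniParameter):
            self._parameters[name] = value
        object.__setattr__(self, name, value)

    def parameters(self):
        # Yield our own parameters, then recurse into each registered sub-module.
        yield from self._parameters.values()
        for child in self._modules.values():
            yield from child.parameters()

    def forward(self, *args):
        raise NotImplementedError("subclasses must override forward()")

    def __call__(self, *args):
        # The real nn.Module also runs hooks around this call.
        return self.forward(*args)

class Scale(MiniModule):
    def __init__(self, factor):
        super().__init__()
        self.factor = MiniParameter(factor)

    def forward(self, x):
        return x * self.factor.value

class TwoScales(MiniModule):
    def __init__(self):
        super().__init__()
        self.first = Scale(2.0)
        self.second = Scale(3.0)

    def forward(self, x):
        return self.second(self.first(x))

net = TwoScales()
print(net(1.0))                             # 6.0
print([p.value for p in net.parameters()])  # [2.0, 3.0]
```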

 

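The __init__ method of torch.nn.Module initializes the module's internal state. It creates the dictionaries that hold the module's parameters, buffers, sub-modules, and hooks, and it sets the training flag. This is why a subclass must call super().__init__() before assigning any layers.

The registration itself happens when you assign an attribute. Module.__setattr__ stores an nn.Parameter in the parameter dictionary and an nn.Module in the sub-module dictionary. A child module is therefore registered with its parent at the moment the parent assigns it.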
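A small experiment shows when registration happens and what goes wrong without super().__init__(). The class names WithInit and WithoutInit are illustrative assumptions. Note that a plain tensor attribute is not registered, so it will not appear in parameters() or state_dict().

```python
import torch
from torch import nn

class WithInit(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)             # registered as a sub-module
        self.scale = nn.Parameter(torch.ones(1))  # registered as a parameter
        self.plain = torch.zeros(1)               # plain tensor: NOT registered

class WithoutInit(nn.Module):
    def __init__(self):
        # super().__init__() is missing, so the internal dictionaries don't exist yet.
        self.linear = nn.Linear(2, 1)

m = WithInit()
print([n for n, _ in m.named_children()])          # ['linear']
print(sorted(n for n, _ in m.named_parameters()))  # ['linear.bias', 'linear.weight', 'scale']

failed_without_init = False
try:
    WithoutInit()
except AttributeError as err:
    failed_without_init = True
    print(err)  # the message says the module was assigned before Module.__init__()
```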

 

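The forward method of torch.nn.Module defines the forward pass of the module. It takes an input tensor and produces an output tensor using the module's parameters and sub-modules. The base class does not implement it (calling it raises NotImplementedError), so every subclass overrides it. In practice we call the model itself, model(x), instead of model.forward(x). Module.__call__ runs any registered hooks and then calls forward().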
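The difference between model(x) and model.forward(x) can be seen with a forward hook. The hook function log_output_shape is a hypothetical example. register_forward_hook is the real nn.Module method, and the hook receives (module, inputs, output).

```python
import torch
from torch import nn

model = nn.Linear(2, 1)
calls = []

# A forward hook runs after forward() whenever the module is invoked via __call__.
def log_output_shape(module, inputs, output):
    calls.append(tuple(output.shape))

handle = model.register_forward_hook(log_output_shape)

x = torch.randn(4, 2)
out_call = model(x)            # goes through __call__: the hook fires
out_direct = model.forward(x)  # bypasses __call__: the hook does NOT fire

print(torch.allclose(out_call, out_direct))  # True: same computation
print(calls)                                 # [(4, 1)]: only model(x) triggered the hook

handle.remove()  # detach the hook once it is no longer needed
```

Calling forward() directly gives the same numbers but silently skips hooks. This is why model(x) is the idiomatic form.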

 

Overall, torch.nn.Module is a powerful and flexible class that provides a foundation for building complex neural network models in PyTorch. It helps to encapsulate the logic and parameters of a neural network module, making it easier to manage and modify.

 

[Import File from GitHub Repository]

Quick update. The dataset we are currently working on is the "make_circles" dataset from "sklearn.datasets". In the previous posts, we built a training loop and tested the accuracy of a neural network with one hidden layer (5 nodes). We used the sigmoid function and rounding to convert the logits into prediction labels, BCEWithLogitsLoss as our loss function, and SGD as our optimizer.

The training and test results were disappointing. The accuracy was 49%, which is no better than random guessing on this two-class problem. So we will visualize the model's predictions and analyze the problem. To do so, we need to import a helper function, and I think going over that code is a good review.
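Here is a minimal sketch of that setup for readers who want to reproduce it. Several details are my assumptions, not necessarily what the original notebook used:
- the sample count, noise level, and seeds
- the 80/20 split
- the learning rate and epoch count
- two plain nn.Linear layers with no activation between them

The accuracy you get may therefore differ.

```python
import torch
from torch import nn
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split

torch.manual_seed(42)

# Assumed settings: 1000 points, small noise, fixed seed.
X, y = make_circles(n_samples=1000, noise=0.03, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_test = (torch.tensor(a, dtype=torch.float32) for a in (X_train, X_test))
y_train, y_test = (torch.tensor(a, dtype=torch.float32) for a in (y_train, y_test))

# One hidden layer with 5 nodes (assumed: no activation between the layers).
model = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 1))
loss_fn = nn.BCEWithLogitsLoss()  # applies the sigmoid internally, so it takes raw logits
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def accuracy(y_true, y_pred):
    # Percentage of predictions that match the labels.
    return (y_true == y_pred).float().mean().item() * 100

for epoch in range(100):
    model.train()
    logits = model(X_train).squeeze(1)
    loss = loss_fn(logits, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

model.eval()
with torch.inference_mode():
    test_logits = model(X_test).squeeze(1)
    # logits -> probabilities (sigmoid) -> 0/1 labels (round)
    test_preds = torch.round(torch.sigmoid(test_logits))
    test_acc = accuracy(y_test, test_preds)
print(f"Test accuracy: {test_acc:.1f}%")
```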

 

[Import GitHub Repo using pathlib]

I think the code is straightforward enough to understand.
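The original code is not reproduced here, so the following is only a sketch of the usual pattern under my own assumptions. It checks with pathlib whether the helper file already exists and downloads it only if it does not. HELPER_URL, helper_functions.py, and download_helper are hypothetical names. The URL is a placeholder to replace with the raw link of your own repository. I use the standard library's urllib.request here, while the original may have used the requests package.

```python
from pathlib import Path
from urllib.request import urlopen

# Hypothetical placeholder: replace with the raw URL of the helper file in your repo.
HELPER_URL = "https://raw.githubusercontent.com/<user>/<repo>/main/helper_functions.py"

def download_helper(url: str, dest: Path) -> bool:
    """Download `url` to `dest` unless the file already exists.

    Returns True if a download happened, False if it was skipped.
    """
    if dest.is_file():
        print(f"{dest} already exists, skipping download")
        return False
    print(f"Downloading {dest}")
    with urlopen(url) as response:
        dest.write_bytes(response.read())
    return True
```

Usage would look like download_helper(HELPER_URL, Path("helper_functions.py")). Once the file sits next to the notebook, it can be imported like any local module, e.g. from helper_functions import <function name>.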
