# Tutorial: Train a PyTorch model

> Learn how to train a PyTorch model in a Serverless Sandbox environment with this step-by-step tutorial.

<Warning>
  Serverless Sandboxes is in public preview.
</Warning>

In this tutorial, you train a PyTorch model in a Serverless Sandbox environment. To do this, you start a sandbox that has your training files and internet access, install the necessary dependencies, and run a Python script that trains a neural network on the UCI Zoo dataset.

By the end of this tutorial, you have a trained PyTorch model file saved locally. This demonstrates how to use a Serverless Sandbox to run isolated ML training workloads without configuring local infrastructure. This tutorial is intended for ML practitioners and developers who want to evaluate Serverless Sandboxes for reproducible training jobs.

## Prerequisites

Before you get started, complete the following setup steps.

### Install the W\&B Python SDK

The W\&B Python SDK provides the `Sandbox` interface you use later to create and interact with the Serverless Sandbox. Install it using `pip`:

```bash theme={null}
pip install wandb
```
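To confirm that the SDK is installed in the Python environment you plan to use for this tutorial, you can query the installed distribution with the standard library. This is a minimal sketch. The helper name `installed_version` is hypothetical and not part of the W\&B SDK.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    wandb_version = installed_version("wandb")
    if wandb_version is None:
        print("wandb is not installed; run `pip install wandb` first.")
    else:
        print(f"wandb {wandb_version} is installed.")
```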

### Log in and authenticate with W\&B

W\&B Serverless Sandboxes run under your W\&B account, so you must authenticate before you create one. Use the `wandb login` CLI command and follow the prompts to log in:

```bash theme={null}
wandb login
```

See the [`wandb login`](/models/ref/cli/wandb-login) reference documentation for more information about how W\&B searches for credentials.
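If you want a quick local check that credentials are available before you create a sandbox, one rough approach is to look for the `WANDB_API_KEY` environment variable or for an `api.wandb.ai` entry in your netrc file, which is where `wandb login` has traditionally stored the key on macOS and Linux. This sketch assumes the default W\&B cloud host and the default `~/.netrc` location. It does not validate the key, and the `wandb login` reference remains the authoritative description of the lookup order. The helper name `has_wandb_credentials` is hypothetical.

```python
import netrc
import os
from pathlib import Path
from typing import Mapping, Optional

# Assumption: the default W&B cloud host. Self-hosted deployments use a different host.
DEFAULT_HOST = "api.wandb.ai"


def has_wandb_credentials(
    env: Mapping[str, str] = os.environ,
    netrc_path: Optional[Path] = None,
    host: str = DEFAULT_HOST,
) -> bool:
    """Return True if an API key appears to be available. The key is not validated."""
    if env.get("WANDB_API_KEY"):
        return True
    path = netrc_path or Path.home() / ".netrc"
    if not path.is_file():
        return False
    try:
        entry = netrc.netrc(str(path)).authenticators(host)
    except netrc.NetrcParseError:
        return False
    # authenticators() returns a (login, account, password) tuple, or None if the host is absent.
    return bool(entry and entry[2])


if __name__ == "__main__":
    print("Credentials found" if has_wandb_credentials() else "Run `wandb login` first")
```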

## Copy the training script and dependencies

Prepare the three files required for this tutorial: a requirements file, a hyperparameters file, and a training script. Expand the following dropdown, then copy each code sample into a separate file. Save all three files in the same directory, because you run the sandbox script from that directory later.

In the next section, you run a script that reads these files and trains a PyTorch model in a W\&B Serverless Sandbox.

<Accordion title="PyTorch training model script">
  Copy and paste the following code into a file named `requirements.txt`. This file contains the dependencies for the training script.

  ```txt title="requirements.txt" theme={null}
  torch
  pandas
  ucimlrepo
  scikit-learn
  pyyaml
  ```

  Copy and paste the following code into a YAML file named `hyperparameters.yaml`. This file contains the hyperparameters for the training script.

  ```yaml title="hyperparameters.yaml" theme={null}
  learning_rate: 0.1
  epochs: 1000
  model_type: Multivariate_neural_network_classifier
  ```

  Copy and paste the following code into a file named `train.py`. This script trains a PyTorch model on the UCI Zoo dataset and saves the trained model to a file named `zoo_wandb.pth`.

  ```python title="train.py" theme={null}
  import argparse
  import torch 
  from torch import nn
  import yaml
  import pandas as pd
  from ucimlrepo import fetch_ucirepo

  from sklearn.model_selection import train_test_split

  class NeuralNetwork(nn.Module):
      def __init__(self):
          super().__init__()
          self.linear_stack = nn.Sequential(
              nn.Linear(in_features=16, out_features=16),
              nn.Sigmoid(),
              nn.Linear(in_features=16, out_features=7)
          )

      def forward(self, x):
          logits = self.linear_stack(x)
          return logits

  def main(args):
      # Load hyperparameters from the provided config file
      with open(args.config, 'r') as f:
          hyperparameter_config = yaml.safe_load(f)

      # fetch dataset 
      zoo = fetch_ucirepo(id=111) 
      
      # data (as pandas dataframes) 
      X = zoo.data.features 
      y = zoo.data.targets

      print("features: ", X.shape, "type: ", type(X))
      print("labels: ", y.shape, "type: ", type(y))

      ## Process data
      # Data type of the data must match the data type of the model, the default dtype for nn.Linear is torch.float32
      dataset = torch.tensor(X.values).type(torch.float32) 

      # Convert to tensor and format labels from 0 - 6 for indexing
      labels = torch.tensor(y.values)  - 1

      print("dataset: ", dataset.shape, "dtype: ",dataset.dtype)
      print("labels: ", labels.shape, "dtype: ",labels.dtype)

      torch.save(dataset, "zoo_dataset.pt")
      torch.save(labels, "zoo_labels.pt")

      # Describe how we split the training dataset for future reference, reproducibility.
      config = {
          "random_state" : 42,
          "test_size" : 0.25,
          "shuffle" : True
      }

      # Split dataset into training and test set
      X_train, X_test, y_train, y_test = train_test_split(
          dataset, labels,
          random_state=config["random_state"],
          test_size=config["test_size"], 
          shuffle=config["shuffle"]
      )

      # Save the files locally
      torch.save(X_train, "zoo_dataset_X_train.pt")
      torch.save(y_train, "zoo_labels_y_train.pt")

      torch.save(X_test, "zoo_dataset_X_test.pt")
      torch.save(y_test, "zoo_labels_y_test.pt")


      ## Define model
      model = NeuralNetwork()
      loss_fn = nn.CrossEntropyLoss()
      optimizer = torch.optim.SGD(model.parameters(), lr=hyperparameter_config["learning_rate"])
      print(model)

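      # Track the best logged loss; start at infinity so the first check always passes
      prev_best_loss = float("inf")

      # Training loop: each step is a full-batch gradient update on the training set
      for e in range(hyperparameter_config["epochs"] + 1):
          pred = model(X_train)
          # CrossEntropyLoss expects 1-D class indices, so drop the trailing label dimension
          loss = loss_fn(pred, y_train.squeeze(1))

          loss.backward()
          optimizer.step()
          optimizer.zero_grad()

          # Every 100 epochs, print the loss if it is no worse than the best loss printed so far
          if (e % 100 == 0) and (loss.item() <= prev_best_loss):
              print("epoch: ", e, "loss:", loss.item())

              # Store the new best loss as a Python float so the autograd graph isn't kept alive
              prev_best_loss = loss.item()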

      print("Saving model...")
      PATH = 'zoo_wandb.pth' 
      torch.save(model.state_dict(), PATH)

  if __name__ == "__main__":
      parser = argparse.ArgumentParser(description="Train a simple neural network on the zoo dataset.")
      parser.add_argument("--config", type=str, required=True, help="Path to the hyperparameter configuration file.")
      args = parser.parse_args()
      main(args)
  ```
</Accordion>
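Before you move on, you can optionally confirm that all three files are present and non-empty in your working directory. The sandbox script in the next section reads them by relative path. This is a minimal sketch that uses only the standard library. The helper name `missing_files` is hypothetical.

```python
from pathlib import Path
from typing import Iterable, List

REQUIRED_FILES = ("requirements.txt", "hyperparameters.yaml", "train.py")


def missing_files(directory: Path, names: Iterable[str] = REQUIRED_FILES) -> List[str]:
    """Return the required files that are absent or empty in `directory`."""
    missing = []
    for name in names:
        path = directory / name
        if not path.is_file() or path.stat().st_size == 0:
            missing.append(name)
    return missing


if __name__ == "__main__":
    problems = missing_files(Path.cwd())
    if problems:
        print("Missing or empty:", ", ".join(problems))
    else:
        print("All training files are in place.")
```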

## Create the sandbox and run the training script

With your training files in place, create and manage a W\&B Serverless Sandbox from a single Python script. The following code snippet creates a sandbox, mounts the training script and requirements file, writes the hyperparameters file, installs the dependencies, runs the training script, and downloads the generated model file.

Copy and paste the following code into a Python file named `train_in_sandbox.py`. Save it in the same directory as the `train.py`, `requirements.txt`, and `hyperparameters.yaml` files that you created in the previous step. Then run it from that directory, for example with `python train_in_sandbox.py`. The list after the code explains what each part does.

```python Show lines title="train_in_sandbox.py" theme={null}
from pathlib import Path
from wandb.sandbox import Sandbox, NetworkOptions

# Files to mount to the sandbox. Each entry is a dictionary with the
# path inside the sandbox and the file's content as bytes.
mounted_files = [
    {"mount_path": "train.py", "file_content": Path("train.py").read_bytes()},
    {"mount_path": "requirements.txt", "file_content": Path("requirements.txt").read_bytes()},
]

print("Starting sandbox...")
with Sandbox.run(
    mounted_files=mounted_files,
    container_image="python:3.13",
    network=NetworkOptions(egress_mode="internet"),
    max_lifetime_seconds=3600
) as sandbox:
    sandbox.write_file("hyperparameters.yaml", Path("hyperparameters.yaml").read_bytes()).result()

    # Install dependencies
    print("Installing dependencies...")
    sandbox.exec(["pip", "install", "-r", "requirements.txt"], check=True).result()

    # Run the script
    print("Running script...")
    result = sandbox.exec(["python", "train.py", "--config", "hyperparameters.yaml"]).result()
    print(result.stdout)
    print(result.stderr)
    print(f"Exit code: {result.returncode}")

    # Save the generated model file locally
    print("Downloading zoo_wandb.pth...")
    model_data = sandbox.read_file("zoo_wandb.pth").result()
    Path("zoo_wandb.pth").write_bytes(model_data)
    print("Saved zoo_wandb.pth")
```

The previous code snippet does the following:

1. (Lines 6-9) List the files to mount to the sandbox: `train.py` and `requirements.txt`.
2. (Lines 12-17) Start the sandbox with the `python:3.13` container image, internet access, and a maximum lifetime of 3600 seconds (1 hour). Internet access lets `pip` download packages and lets the training script download the UCI Zoo dataset with `ucimlrepo`.
3. (Line 18) Write the `hyperparameters.yaml` file to the sandbox. This lets the training script (`train.py`) access the hyperparameters when it runs.
4. (Line 22) Install dependencies. The command `pip install -r requirements.txt` runs inside the sandbox to install the necessary dependencies for the training script.
5. (Line 26) Run the training script. The command `python train.py --config hyperparameters.yaml` runs inside the sandbox to start the training process. The script trains a PyTorch model on the UCI Zoo dataset and saves the trained model to a file named `zoo_wandb.pth`.
6. (Lines 27-29) Print the output and exit code. After the training script finishes, the standard output, standard error, and exit code are printed to the console for debugging and verification.
7. (Lines 33-34) Download the generated model file. The `read_file()` method reads `zoo_wandb.pth` from the sandbox, and the script saves it locally.
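The installation step passes `check=True`, but the training step does not. If training fails, the script prints the error output and still tries to download `zoo_wandb.pth`, which may not exist. To stop with a clear message instead, you could check the result before you download the model. This sketch relies only on the `returncode`, `stdout`, and `stderr` attributes that the tutorial's code already reads. The helper name `ensure_success` is hypothetical.

```python
from types import SimpleNamespace


def ensure_success(result, step: str) -> None:
    """Raise RuntimeError with the captured output if a sandbox command failed.

    `result` is assumed to expose `returncode`, `stdout`, and `stderr`, as the
    object returned by `sandbox.exec(...).result()` does in this tutorial.
    """
    if result.returncode != 0:
        raise RuntimeError(
            f"{step} failed with exit code {result.returncode}\n"
            f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
        )


# Usage inside the `with Sandbox.run(...)` block, before downloading the model:
#   result = sandbox.exec(["python", "train.py", "--config", "hyperparameters.yaml"]).result()
#   ensure_success(result, "Training")
#   model_data = sandbox.read_file("zoo_wandb.pth").result()

if __name__ == "__main__":
    # Stand-in objects so you can try the helper without a sandbox.
    ensure_success(SimpleNamespace(returncode=0, stdout="ok", stderr=""), "Demo")
    try:
        ensure_success(SimpleNamespace(returncode=1, stdout="", stderr="Traceback ..."), "Demo")
    except RuntimeError as err:
        print(err)
```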

After the script completes, you have a trained PyTorch model saved as `zoo_wandb.pth` in your working directory. Because the sandbox is opened in a `with` block, it exists only for this job and is torn down when the block exits.
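This tutorial installs PyTorch only inside the sandbox, so your local environment may not have `torch`. In PyTorch 1.6 and later, `torch.save()` writes a zip archive by default. This means you can use the standard library to run a quick local sanity check that the download is non-empty and well-formed. If you have PyTorch installed locally, loading the weights into `NeuralNetwork` with `torch.load()` and `load_state_dict()` is the real test. This sketch only checks the archive structure. The helper name `looks_like_torch_checkpoint` is hypothetical.

```python
import zipfile
from pathlib import Path


def looks_like_torch_checkpoint(path: Path) -> bool:
    """Cheap structural check: the file exists, is non-empty, and is a valid zip archive.

    Assumes the checkpoint was written with torch.save()'s default zip-based format
    (PyTorch 1.6+). It does not verify that the tensors inside are usable.
    """
    if not path.is_file() or path.stat().st_size == 0:
        return False
    if not zipfile.is_zipfile(path):
        return False
    with zipfile.ZipFile(path) as archive:
        # testzip() returns the name of the first corrupt member, or None if all CRCs match.
        return archive.testzip() is None


if __name__ == "__main__":
    model_path = Path("zoo_wandb.pth")
    if looks_like_torch_checkpoint(model_path):
        print(f"{model_path} looks valid ({model_path.stat().st_size} bytes)")
    else:
        print(f"{model_path} is missing or malformed")
```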
