
I Tried Using DL To Optimize Python’s Garbage Collection (1/2)

How I Got This Idea

One of my friends was building an RL-based system to train a bipedal walker to … well, walk. Ideally, one would use libraries like Gym or NEAT, but he was stubborn and wanted to build everything from scratch. Anyhow, during that process, he fed things like the foot’s angular momentum, the hand’s velocity, the torso’s energy, and whatnot into a neural network to see how it performed. That’s when it struck me: why not apply a similar approach to something low-level? Now, I haven’t yet gone through CPython’s internal implementation of these things, so for the time being I’m gonna stick to the Python world. This allows me to treat GC as something I can influence externally rather than modify internally.


The repo for this whole thing, along with detailed explanations, architecture diagrams, usage, etc., can be found here

Why? Aren’t GC-Related Issues More Prevalent In Java?

True. Mistimed GC can lead to memory spikes that end in OOM errors, and GC pauses bring their own performance issues. In fact, Java was going to be my language of choice for this specific experiment because a LOT of things can be tuned in Java’s GC … heck, even the GC itself can be swapped between multiple implementations.


And even from a machine learning standpoint, that’s nice, right? More knobs mean more dimensions for the model to explore and optimize … so why not use Java? Well, the reason is simple: I’m not that proficient in Java 😅


I could use ChatGPT to get a server running and somehow figure out a way to invoke the garbage collector … but I’ll be using PyTorch for my models, and getting that up and running alongside a Java process seems like a chore. So, for this post, I decided to use Python end-to-end.

Note

Python primarily manages memory using reference counting. This means that when an object’s reference count drops to zero, its memory is freed immediately, without involving the garbage collector at all.

As a result, you can’t trigger meaningful GC activity just by creating and deleting a lot of short-lived objects. To actually exercise Python’s garbage collector, you need to create reference cycles (for example, objects that reference each other), because these won’t be reclaimed by reference counting alone and require an explicit GC pass to clean up.
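To make that concrete, here is a minimal sketch using only the standard gc module: a reference cycle survives refcounting and is only reclaimed by an explicit collection pass.

```python
import gc

gc.disable()  # suspend automatic collection to isolate the effect

class Node:
    def __init__(self):
        self.ref = None

# Two objects referencing each other form a cycle, so their
# reference counts never drop to zero ...
a, b = Node(), Node()
a.ref, b.ref = b, a
del a, b  # ... deleting the names does NOT free the objects

# ... only an explicit GC pass finds and reclaims the cycle
unreachable = gc.collect()
gc.enable()
print(unreachable)  # > 0: the cycle was freed by the collector, not refcounting
```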

Alright Okay 😓 What Next ..

Input / Output

Now, for any ML model, the first question is: what are the inputs and outputs?

Well, the output is straightforward, as it boils down to a binary decision: if our model predicts a value greater than some threshold, trigger the GC using gc.collect().
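In code, that decision is a one-liner wrapped in a helper. This is a hypothetical sketch: maybe_collect and the THRESHOLD value are illustrative names, not taken from the repo.

```python
import gc

THRESHOLD = 0.5  # assumed cutoff, tune per application

def maybe_collect(score: float) -> bool:
    """Trigger a GC pass when the model's confidence exceeds the cutoff."""
    if score > THRESHOLD:
        gc.collect()
        return True
    return False
```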


The inputs are … noisier, to say the least. Since I’m not directly tapping into Python’s internal GC state, I use process-level metrics:

Note

Choosing system metrics isn’t perfectly accurate because the “system” is running other processes as well. For example, 50% RAM usage doesn’t necessarily mean your application alone is using 50%.

To reduce this noise, I rely on process-level metrics instead of global system metrics.

from dataclasses import dataclass

@dataclass
class ProfileMetrics:
    time: float = 0.0
    cpu: float = 0.0
    mem: float = 0.0
    disk_read: float = 0.0
    disk_write: float = 0.0
    net_sent: float = 0.0
    net_recv: float = 0.0
    rps: float = 0.0
    p95: float = 0.0
    p99: float = 0.0
    gc_triggered: bool = False

The output is a single floating-point value between 0 and 1, representing the model’s confidence that a garbage collection cycle should be triggered:

import numpy as np

# df is a pandas DataFrame of the collected ProfileMetrics rows
mem_pressure = df["mem"].values / 100.0
cpu_factor = df["cpu"].values / 100.0
gc_recent = df["gc_triggered"].astype(float).values

# Our output: a heuristic "GC urgency" score, clipped to [0, 1]
self.targets = np.clip(
    0.4 * mem_pressure + 0.3 * cpu_factor + 0.3 * (1 - gc_recent * 0.5),
    0.0,
    1.0,
).astype(np.float32)
Note

In this implementation, I construct this target using a simple heuristic rather than directly tuning Python’s internal GC parameters. While it is possible to influence GC behavior by adjusting gc.set_threshold(), I deliberately avoid doing so here to keep the experiment simple.
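For reference, those knobs are a one-liner away. A small sketch of what gc.get_threshold() / gc.set_threshold() expose, even though this experiment leaves the defaults alone (the 10_000 value below is just an illustration):

```python
import gc

# The three values are per-generation thresholds: roughly, how much
# net allocation activity accumulates before a generation is collected
print(gc.get_threshold())

# Tuning is one call, e.g. making generation-0 collections rarer:
gc.set_threshold(10_000, 10, 10)
```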

Model

Carrying the light-hearted spirit of an “MVP”, I chose a dead-simple LSTM-based model. Since timestamps are involved and the data is sequential, LSTMs are a natural baseline for time-series behavior modeling:

import torch
import torch.nn as nn

class LSTMNetwork(nn.Module):
    def __init__(
        self,
        input_size: int = 10,
        hidden_size: int = 64,
        num_layers: int = 2,
        dropout: float = 0.2,
    ):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )

        self.fc = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lstm_out, _ = self.lstm(x)
        last_output = lstm_out[:, -1, :]  # hidden state at the final timestep
        out = self.fc(last_output)
        out = self.sigmoid(out)
        return out

Optimizers and Training

There are two main approaches here:

  • Offline training : Profile an application under load, collect data, train a model, and then deploy the trained model alongside the application.
  • Online training : Continuously update the model as new data comes in, with no fixed dataset.

I chose the offline approach. I wanted to see how a model trained specifically on a given application’s behavior would perform when reintroduced into that same environment. For optimization, I used Adam with a learning rate of 1e-3 to keep things simple and stable.

Environment and Training

For the environment, I use a FastAPI-based web server with the following endpoints to simulate the respective load types (these do things like calculating a bunch of prime numbers, opening/reading/closing files, etc.):

@app.get("/cpu-heavy")
@app.get("/memory-heavy")
@app.get("/network-heavy")
@app.get("/io-heavy")

In addition, it contains the relevant code for running the profiler in the background and exposes an API to get the process’s metrics:

@app.get("/metrics")
async def get_metrics():
    metrics = profiler.get_metrics()
    data = metrics.to_dict()
    # .. other processing ..
    return data

The full code for these can be obtained here. For testing the server under low and high load, I’m using Locust. While training the ML model, I run the FastAPI server without NeuroGC (this is what I call it 🤓), load test it with Locust, and profile the relevant metrics from the server. Once that is done, I train the model on those metrics using a straightforward torch training loop. After training, I save the model as a .pth file (for loading it later).
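That training loop is nothing exotic. Here is a hedged sketch of what it might look like, with random tensors standing in for the real profiled sequences; the shapes, epoch count, loss choice, and file name are assumptions, while the Adam settings are the ones mentioned above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for the profiled dataset: 64 windows of 20 timesteps,
# 10 features each, with heuristic targets in [0, 1]
X = torch.rand(64, 20, 10)
y = torch.rand(64, 1)

# Same shape of model as LSTMNetwork above: LSTM + linear head + sigmoid
lstm = nn.LSTM(input_size=10, hidden_size=64, num_layers=2,
               batch_first=True, dropout=0.2)
head = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())

optimizer = torch.optim.Adam(
    list(lstm.parameters()) + list(head.parameters()), lr=1e-3
)
criterion = nn.MSELoss()  # regress toward the heuristic urgency score

for epoch in range(5):
    optimizer.zero_grad()
    out, _ = lstm(X)
    pred = head(out[:, -1, :])  # last timestep -> confidence score
    loss = criterion(pred, y)
    loss.backward()
    optimizer.step()

# Persist the weights for the evaluation server to load later
torch.save({"lstm": lstm.state_dict(), "head": head.state_dict()},
           "neurogc_model.pth")
```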

Testing The Model

This involves a few things.

  • First, I create another FastAPI-based web server (similar to the one mentioned above), but when it starts, I load the model and run NeuroGC in the background using asyncio.create_task(). This triggers (or doesn’t trigger) GC based on the model’s prediction.

  • Second, we need a way to see the metrics update in real time. There are two types of metrics I’m looking for: process-level metrics like CPU, RAM, etc., and server performance metrics like p95, p99, etc. We already get the process metrics via the /metrics endpoint. For the others, we can use Locust, since it’s already making a ton of requests. We need to club these two together so we can show them all in the UI. The glue code for that is a simple class that runs in the background of Locust’s process:

class ProfileCollector:
    def __init__(
        self,
        profile_interval: float = 1.0,
        metrics_server_url: str = "http://localhost:8003",
    ):
        self.profile_interval = profile_interval
        self.metrics_server_url = metrics_server_url

        ...

    ...

    def _fetch_server_metrics(self) -> None:
        try:
            with httpx.Client(timeout=2.0) as client:
                try:
                    resp = client.get(f"{self.server_with_gc_url}/metrics")
                    if resp.status_code == 200:
                        self._server_metrics_with_gc = resp.json()
                except Exception:
                    pass

                try:
                    resp = client.get(f"{self.server_without_gc_url}/metrics")
                    if resp.status_code == 200:
                        self._server_metrics_without_gc = resp.json()
                except Exception:
                    pass
        except Exception:
            pass

    def _run_loop(self) -> None:
        while self._running:
            try:
                # Fetch actual metrics from each server
                self._fetch_server_metrics()

                metrics_with_gc = self._get_metrics_for_server("with_gc")
                metrics_without_gc = self._get_metrics_for_server("without_gc")

                # Send the metrics to metrics_server
                try:
                    with httpx.Client(timeout=5.0) as client:
                        client.post(
                            f"{self.metrics_server_url}/api/metrics",
                            json=metrics_with_gc,
                        )
                        client.post(
                            f"{self.metrics_server_url}/api/metrics",
                            json=metrics_without_gc,
                        )
                except Exception:
                    pass

            except Exception as e:
                print(f"[ProfileCollector] Error: {e}")

            time.sleep(self.profile_interval)

    def start(self) -> None:
        self._running = True

        if self.profiler:
            self.profiler.start()

        self._thread = threading.Thread(
            target=self._run_loop,
            daemon=True,
        )
        self._thread.start()

        print(
            f"[ProfileCollector] Started. Posting to {self.metrics_server_url}"
        )

I use the @events.init.add_listener decorator and start the ProfileCollector in the on_locust_init hook.

  • Third, to see the metrics update in real time, I didn’t want to involve a huge library like React for this project, so I built a simple UI using HTML, WebSockets, and Chart.js. The backend for this is yet another web server called the metrics_server. It:
    • Serves the UI
    • Receives the metrics from the Locust process via the /api/metrics endpoint
    • Streams the latest metrics to the UI over a WebSocket connection
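The first step above — running NeuroGC as a background task at server startup — can be sketched roughly like this. Everything here is a hypothetical shape, not code from the repo: neurogc_loop, the predict callable (standing in for running the loaded LSTM on the latest metrics window), and the THRESHOLD cutoff are all illustrative.

```python
import asyncio
import gc

THRESHOLD = 0.5  # assumed cutoff

async def neurogc_loop(predict, interval: float = 1.0):
    """Periodically ask the model whether a collection is worth it."""
    while True:
        if predict() > THRESHOLD:  # model's confidence in [0, 1]
            gc.collect()
        await asyncio.sleep(interval)

# In the FastAPI app this would be launched at startup, e.g.:
# asyncio.create_task(neurogc_loop(model_predict))
```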

… Too many servers, I know. I apologize 😭

Yeah .. 😭😩 (click to enlarge ... actually don't)

Anyhow, with that setup, we first start the servers with and without NeuroGC in two different terminals, then start the metrics_server in another terminal, and finally start the Locust process in yet another terminal. The real-time dashboard UI can be viewed at localhost:8003

Note

This might seem like a chore to do every time (if you are running this locally) … I would suggest taking a look at tmux; maybe that will ease the process

Benchmarks

System Information

Property            Value
Operating System    macOS 14.6
Architecture        arm64
CPU                 arm
CPU Cores           8 (logical: 8)
Memory              24.0 GB
Disk                460.4 GB
Python Version      3.14.0

Light load

  • Training Load : locust -f locustfile.py --headless -u 100 -r 10 --run-time 1m
  • Evaluation Load : locust -f locustfile.py --headless -u 100 -r 10 --run-time 1m

Dataset and Visualizations

With and without NeuroGC on low load (click to enlarge)

Full data :

Metric              Without NeuroGC   With NeuroGC   Improvement
Avg CPU (%)         38.0              29.0           🟢 +23.7%
Avg Memory (%)      50.1              50.1           0.0%
Avg Disk Read       14410.73          20069.47       🔴 -39.3%
Avg Disk Write      9903865.30        6187022.12     🟢 +37.5%
Avg Net Sent        81743.25          87163.18       🔴 -6.6%
Avg Net Recv        102192.21         99026.00       🟢 +3.1%
P95 Latency (ms)    3538.8            3609.9         🔴 -2.0%
P99 Latency (ms)    4530.3            4769.2         🔴 -5.3%
Avg RPS             31.2              30.7           🔴 -1.7%
GC Events           19                17             🔴 -10.5%

High Load

  • Training Load : locust -f locustfile.py --headless -u 500 -r 10 --run-time 1m
  • Evaluation Load : locust -f locustfile.py --headless -u 500 -r 10 --run-time 1m

Dataset and Visualizations

With and without NeuroGC on high load (click to enlarge)

Full data :

Metric              Without NeuroGC   With NeuroGC   Improvement
Avg CPU (%)         40.2              43.2           🔴 -7.4%
Avg Memory (%)      53.5              53.5           0.0%
Avg Disk Read       10371.80          4011.52        🟢 +61.3%
Avg Disk Write      6082227.94        5377198.66     🟢 +11.6%
Avg Net Sent        68960.00          62821.57       🟢 +8.9%
Avg Net Recv        78285.45          68550.20       🟢 +12.4%
P95 Latency (ms)    4297.5            4424.6         🔴 -3.0%
P99 Latency (ms)    9652.6            8373.4         🟢 +13.3%
Avg RPS             29.5              28.6           🔴 -3.1%
GC Events           15                15             0.0%

Conclusion

With this LSTM-based approach, NeuroGC didn’t deliver a dramatic, across-the-board improvement, and that’s honestly not surprising.


What this experiment did show is that application behavior can be learned in a meaningful way using only high-level process metrics. Even with a relatively small dataset and a simple sequence model, the system was able to influence GC timing and, in some cases, reduce disk and network pressure under high load.


That alone suggests there’s signal here … it’s just not being fully captured yet. In many ways, this LSTM model should be treated as a baseline rather than a final solution.


There are a few obvious directions for improvement :

  • Larger datasets : Extending the load test duration from 1 minute to 10–30 minutes would significantly improve temporal coverage and reduce noise in the training data.
  • Richer feature engineering : Adding derived signals (rates of change, rolling averages, memory allocation deltas, request burstiness) may give the model a clearer picture of why pressure is building.
  • Stronger model architectures : Transformers, temporal convolutional networks, or even reinforcement learning could be better suited to learning long-term effects and delayed rewards in GC behavior.
  • Online or hybrid training : Allowing the model to adapt as the application evolves might make it more robust in real-world deployments.

So while the LSTM didn’t significantly outperform Python’s GC, it did something arguably more important:

It validated the idea that garbage collection can be treated as a learnable system rather than a fixed heuristic.


In Part 2 / 2, I’ll explore more expressive models, longer training runs, and whether this approach can move from “interesting experiment” to something that actually holds up under sustained production-level load. Wish me luck ! 😁🤞
