
Jane Street puzzle writeup: Robot Baseball


contact@korff.dev

Overview:

This puzzle asks for the maximum probability q that a game of baseball played between two robots reaches a "full count" (three balls and two strikes).

At each turn, the two robots (one as the pitcher and the other as the batter) simultaneously make one of two choices:

  • The pitcher can either throw a ball or a strike.

  • The batter can either swing or wait.

A running tally of strikes and balls is also maintained.

This process leads to the following matrix:

| | Pitcher throws ball | Pitcher throws strike |
|---|---|---|
| Batter swings | strikes + 1 | home run with probability h, otherwise strikes + 1 |
| Batter waits | balls + 1 | strikes + 1 |

A game ends as soon as any of the following conditions is met:

  • The count of balls reaches 4 → Results in the batter gaining 1 point

  • The count of strikes reaches 3 → Results in the batter gaining 0 points

  • The batter hits a home run → Results in the batter gaining 4 points

The batter aims to maximize its expected score, while the pitcher aims to minimize it.

Furthermore, both robots play optimal strategies, and h has been set to maximize the value of q.
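The rules above are small enough to write down directly. This is a minimal sketch of a single pitch and of the end conditions; the function names `pitch_outcomes` and `terminal_score` and the outcome labels are my own, not from the puzzle or the gist.

```python
def pitch_outcomes(pitcher_throws_ball: bool, batter_swings: bool, h: float) -> dict:
    """Outcome probabilities of one pitch, following the matrix above.

    Hypothetical helper: maps 'ball', 'strike' or 'home_run' to a probability.
    """
    if pitcher_throws_ball:
        # Swinging at a ball counts as a strike; waiting earns a ball.
        return {"strike": 1.0} if batter_swings else {"ball": 1.0}
    if batter_swings:
        # Swinging at a strike: home run with probability h, otherwise a strike.
        return {"home_run": h, "strike": 1.0 - h}
    # Waiting on a strike is a called strike.
    return {"strike": 1.0}


def terminal_score(balls: int, strikes: int, home_run: bool = False):
    """Batter's score if the game is over, otherwise None."""
    if home_run:
        return 4
    if balls == 4:
        return 1
    if strikes == 3:
        return 0
    return None


print(pitch_outcomes(False, True, 0.25))  # {'home_run': 0.25, 'strike': 0.75}
```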

Approach:

  1. Find the optimal strategy for the bots

  2. Find a function for q that depends on h

  3. Find the maximum value of this function

The bots' strategy:

Since there is no obvious best strategy, I turned to the Nash equilibrium. The idea is that neither bot can be exploited: each bot randomizes its choice so that its opponent has nothing to gain by preferring one action over the other, i.e. the opponent is indifferent between its two actions.

So we compute the expected value of each action against the opponent's mix of choices and set the two expected values equal.

Let b be the probability of the batter swinging and p the probability of the pitcher throwing a ball.
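The indifference conditions can be solved once for any 2×2 zero-sum game. A minimal sketch, assuming the game has no pure-strategy saddle point (otherwise the formulas do not describe the equilibrium); `solve_2x2` is a hypothetical helper. Rows are the batter's actions (swing, wait), columns the pitcher's (ball, strike), and entries are the batter's expected score.

```python
def solve_2x2(m):
    """Mixed equilibrium of a 2x2 zero-sum game via indifference.

    m[row][col] is the batter's payoff; rows = (swing, wait),
    cols = (ball, strike). Assumes no pure saddle point.
    Returns (b, p, value): b = P(batter swings), p = P(pitcher throws a ball).
    """
    (m00, m01), (m10, m11) = m
    denom = m00 - m01 - m10 + m11
    # p makes the batter indifferent: E[swing] == E[wait]
    p = (m11 - m01) / denom
    # b makes the pitcher indifferent: E[ball] == E[strike]
    b = (m11 - m10) / denom
    # With the batter indifferent, the swing row gives the game value
    value = p * m00 + (1 - p) * m01
    return b, p, value


h = 0.25
full_count = [[0.0, 4 * h], [1.0, 0.0]]  # the 3-2 payoff matrix from the next section
print(solve_2x2(full_count))  # (0.5, 0.5, 0.5)
```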

But how do we get the expected values?

The expected values change with the game state: if the count is already at 2 strikes, the batter's expected value is lower than if the count is at 3 balls instead.

Let's start with the simplest case, 3 balls and 2 strikes. Whatever the bots choose, the game ends on this pitch, and the payoffs are known from the puzzle description. So let's construct a payoff matrix (entries are the batter's score):

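| | Pitcher throws ball (p) | Pitcher throws strike (1-p) | Expected score (batter's action) |
|---|---|---|---|
| Batter swings (b) | 0 | h*4 + (1-h)*0 | = (h*4)*(1-p) |
| Batter waits (1-b) | 1 | 0 | = p |
| Expected score (pitcher's action) | = 1*(1-b) | = (h*4)*b | |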

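Setting the batter's two expected scores equal (so the batter gains nothing by switching) and the pitcher's two expected scores equal (so the pitcher gains nothing by switching) gives:

$$4 \cdot h \cdot (1-p) = p \;\qquad 1-b = 4 \cdot h \cdot b$$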

$$p = \frac{4h}{1+4h} \qquad b = \frac{1}{1+4h}$$
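Because the batter is indifferent at the equilibrium, the value $V_{3,2}$ of the full-count state (the batter's expected score at 3 balls and 2 strikes) can be read off the "wait" row, where a walk scores 1 and a called third strike scores 0:

$$V_{3,2} = p \cdot 1 + (1-p) \cdot 0 = \frac{4h}{1+4h}, \qquad b = \frac{1}{1+4h} = 1 - p$$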

Now we have functions for p and b, so we could compute the expected value of this game state as a function of h. Doing this by hand for every game state is tedious, though, so it is a task best left to a computer.
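For any other count, the entries of the payoff matrix are no longer final scores but the values of the follow-up states. A sketch of that general step, in my own notation (not from the puzzle or the gist): let $V_{i,j}$ be the batter's expected score at $i$ balls and $j$ strikes, $A = V_{i+1,j}$ the value after a ball and $K = V_{i,j+1}$ the value after a strike, with $V_{4,j} = 1$ and $V_{i,3} = 0$. With rows swing/wait and columns ball/strike, the matrix is

$$\begin{pmatrix} K & 4h + (1-h)K \\ A & K \end{pmatrix}$$

and the same indifference conditions give

$$p_{i,j} = \frac{h(4-K)}{h(4-K) + A - K}, \qquad b_{i,j} = \frac{A-K}{h(4-K) + A - K} = 1 - p_{i,j}, \qquad V_{i,j} = p_{i,j} A + (1-p_{i,j}) K$$

With $A = 1$ and $K = 0$ this reduces to the 3-2 formulas above.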

If we treat the game states as nodes in a dependency graph, where each count links to the count with one more ball and the count with one more strike, we get the following structure:

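# If you are interested in the code, you will find the full code in the GitHub Gist.
# The Node class below is a minimal stand-in (hypothetical; the gist's version may differ):
# a game state with b balls and s strikes, linked to the state one ball later (b1)
# and the state one strike later (s1).
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Node:
  b: int
  s: int
  # Links are excluded from ==, so "child in nodes" compares the count only
  b1: Optional["Node"] = field(default=None, compare=False, repr=False)
  s1: Optional["Node"] = field(default=None, compare=False, repr=False)

nodes = []

root = Node(0,0)

def build_graph(root):

  # If one extra ball is possible, create that child node
  if root.b != 3:
    child = Node(root.b + 1, root.s)

    if child not in nodes:
      nodes.append(child)
      root.b1 = child
      build_graph(child)
    else:
      root.b1 = nodes[nodes.index(child)]

  # If one extra strike is possible, create that child node
  if root.s != 2:
    child = Node(root.b, root.s + 1)

    if child not in nodes:
      nodes.append(child)
      root.s1 = child
      build_graph(child)
    else:
      root.s1 = nodes[nodes.index(child)]

nodes.append(root)
build_graph(root)

print(f"Graph nodes: {len(nodes)}")  # Graph nodes: 12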
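Once the states are linked, their values can be filled in backwards from the full count. The sketch below skips the explicit graph and uses a memoized recursion over (balls, strikes) instead, solving each state's 2×2 game with the closed-form indifference solution. It assumes 0 < h ≤ 1 (at h = 0 some denominators vanish), and `make_solver` / `solve_state` are hypothetical names, not the gist's code.

```python
from functools import lru_cache


def make_solver(h: float):
    """Return a memoized equilibrium solver for home-run probability h."""

    @lru_cache(maxsize=None)
    def solve_state(balls: int, strikes: int):
        """Return (value, p, b): batter's expected score,
        P(pitcher throws a ball), P(batter swings)."""
        if balls == 4:
            return 1.0, None, None   # walk
        if strikes == 3:
            return 0.0, None, None   # strikeout
        a = solve_state(balls + 1, strikes)[0]   # value after a ball
        k = solve_state(balls, strikes + 1)[0]   # value after a strike
        denom = h * (4 - k) + (a - k)
        p = h * (4 - k) / denom   # makes the batter indifferent
        b = (a - k) / denom       # makes the pitcher indifferent (equals 1 - p)
        value = p * a + (1 - p) * k
        return value, p, b

    return solve_state


solve = make_solver(0.25)
print(solve(3, 2))  # (0.5, 0.5, 0.5)
```

Calling `make_solver(h)(0, 0)` then gives the batter's expected score at the start of the game, together with both robots' opening probabilities.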

Finding h and q

Now that we have p and b for every possible game state, we have solved the game.

That alone would hardly be an entertaining viewing experience, so we need to find the value for h maximizing q.

I did this by mapping every path from the 0,0 node to the 3,2 node and summing the probabilities of those paths.

# If you are interested in the code, you will find the full code in the GitHub Gist

Now we have a function giving us q for a given h.
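A minimal sketch of such a function: `q_of_h` is a hypothetical stand-in for the gist's `sum_func`, not a copy of it, and it assumes 0 < h ≤ 1. At the indifference solution the batter's swing probability works out to b = 1 - p at every count, so a pitch ends in a ball with probability p(1-b) = p², in a home run with probability b(1-p)h = (1-p)²h, and in a strike otherwise. Every ordering of three balls and two strikes is a valid path to 3-2, so q is the sum of the products along those 10 paths.

```python
from functools import lru_cache
from itertools import combinations


def q_of_h(h: float) -> float:
    """Probability of reaching a full count (3 balls, 2 strikes) when both
    robots play the equilibrium strategies for home-run probability h."""

    @lru_cache(maxsize=None)
    def state(balls, strikes):
        """(value, p) at a count; p = P(pitcher throws a ball)."""
        if balls == 4:
            return 1.0, None
        if strikes == 3:
            return 0.0, None
        a = state(balls + 1, strikes)[0]
        k = state(balls, strikes + 1)[0]
        p = h * (4 - k) / (h * (4 - k) + a - k)
        return p * a + (1 - p) * k, p

    def step_probs(balls, strikes):
        # The batter swings with b = 1 - p, so a ball is called with p * (1 - b) = p**2
        p = state(balls, strikes)[1]
        p_ball = p * p
        p_strike = 1 - p_ball - h * (1 - p) ** 2   # the remainder is a home run
        return p_ball, p_strike

    total = 0.0
    # A path is an ordering of 3 balls and 2 strikes: choose which of the 5 pitches are strikes
    for strike_positions in combinations(range(5), 2):
        balls = strikes = 0
        prob = 1.0
        for pitch in range(5):
            p_ball, p_strike = step_probs(balls, strikes)
            if pitch in strike_positions:
                prob *= p_strike
                strikes += 1
            else:
                prob *= p_ball
                balls += 1
        total += prob
    return total


print(q_of_h(0.22697434289547633))  # ≈ 0.29597
```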

Now we just need to find the maximum within the interval [0,1].

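from scipy.optimize import minimize_scalar

# sum_func(h) returns q for a given h (defined in the gist);
# minimizing -q over the bounds maximizes q
res = minimize_scalar(lambda h: -sum_func(h), bounds=(0,1), method='bounded')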

print("h =", res.x)                # h = 0.22697434289547633
print("q =", sum_func(res.x))      # q = 0.2959679933709649
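If SciPy isn't available, the same bounded maximization can be done with a golden-section search. A minimal sketch, assuming the objective is unimodal on [0, 1]; SciPy's `method='bounded'` also only guarantees a local optimum (it uses Brent's method, which combines golden-section steps with parabolic interpolation). `golden_max` is a hypothetical helper name.

```python
import math


def golden_max(f, lo: float, hi: float, tol: float = 1e-10):
    """Maximize a unimodal function f on [lo, hi] by golden-section search.
    Returns (x, f(x))."""
    inv_phi = (math.sqrt(5) - 1) / 2   # 1/phi, about 0.618
    x1 = hi - inv_phi * (hi - lo)
    x2 = lo + inv_phi * (hi - lo)
    f1, f2 = f(x1), f(x2)
    while hi - lo > tol:
        if f1 < f2:
            # The maximum lies in [x1, hi]; the old x2 becomes the new left probe
            lo, x1, f1 = x1, x2, f2
            x2 = lo + inv_phi * (hi - lo)
            f2 = f(x2)
        else:
            # The maximum lies in [lo, x2]; the old x1 becomes the new right probe
            hi, x2, f2 = x2, x1, f1
            x1 = hi - inv_phi * (hi - lo)
            f1 = f(x1)
    x = (lo + hi) / 2
    return x, f(x)


# With the gist's function this would be: h, q = golden_max(sum_func, 0.0, 1.0)
x, fx = golden_max(lambda x: -(x - 0.3) ** 2, 0.0, 1.0)
print(round(x, 6))  # 0.3
```

Each iteration reuses one of the two previous probes, so only one new evaluation of the objective is needed per step.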

And there we have it.

The maximal value for q is: 0.2959679933

Full code:

View Github Gist