<?xml version="1.0" encoding="utf-8"?> 
<feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us">
    <generator uri="https://gohugo.io/" version="0.152.2">Hugo</generator><title type="html"><![CDATA[PyTrees on Blog]]></title>
    
    
    
            <link href="https://blog.scientific-python.org/tags/pytrees/" rel="alternate" type="text/html" title="html" />
            <link href="https://blog.scientific-python.org/tags/pytrees/atom.xml" rel="self" type="application/atom" title="atom" />
    <updated>2026-04-04T04:32:36+00:00</updated>
    
    
    
    
        <id>https://blog.scientific-python.org/tags/pytrees/</id>
    
        
        <entry>
            <title type="html"><![CDATA[Pytrees for Scientific Python]]></title>
            <link href="https://blog.scientific-python.org/pytrees/?utm_source=atom_feed" rel="alternate" type="text/html" />
            
            
                <id>https://blog.scientific-python.org/pytrees/</id>
            
            
            <published>2025-07-08T00:00:00+00:00</published>
            <updated>2025-07-08T00:00:00+00:00</updated>
            
            
            <content type="html"><![CDATA[<blockquote>Introducing PyTrees for Scientific Python. We discuss what PyTrees are, how they&rsquo;re useful in the realm of scientific Python, and how to work <em>efficiently</em> with them.</blockquote><h2 id="manipulating-tree-like-data-using-functional-programming-paradigms">Manipulating Tree-like Data using Functional Programming Paradigms<a class="headerlink" href="#manipulating-tree-like-data-using-functional-programming-paradigms" title="Link to this heading">#</a></h2>
<p>A &ldquo;PyTree&rdquo; is a nested collection of Python containers (e.g. dicts, (named) tuples, lists, &hellip;) whose leaves hold the data of interest.
In the scientific world, such a PyTree could hold experimental measurements of different properties at different timestamps and measurement settings, resulting in a highly complex, nested, and not necessarily rectangular data structure.
Such collections can be cumbersome to manipulate <em>efficiently</em>, especially if they are nested to arbitrary depth.
Doing so often requires complex recursive logic that usually does not generalize to other nested Python containers (PyTrees), e.g. to a new set of measurements.</p>
<p>The core concept of PyTrees is being able to flatten them into a flat collection of leaves and a &ldquo;blueprint&rdquo; of the tree structure, and then being able to unflatten them back into the original PyTree.
This makes it possible to apply generic transformations to the leaves, independent of how they are nested.
In this blog post, we use <a href="https://github.com/metaopt/optree/tree/main/optree"><code>optree</code></a>, a standalone PyTree library that provides these transformations. It focuses on performance, is feature-rich, has minimal dependencies, and has been adopted as a core dependency by <a href="https://pytorch.org">PyTorch</a>, <a href="https://keras.io">Keras</a>, and <a href="https://github.com/tensorflow/tensorflow">TensorFlow</a> (through Keras).
For example, given a PyTree with NumPy arrays as leaves, we can take the square root of each leaf with <code>optree.tree_map(np.sqrt, tree)</code>:</p>


<div class="highlight">
  <pre class="chroma"><code><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">optree</span> <span class="k">as</span> <span class="nn">pt</span>
</span></span><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># a tuple holding a list, which contains a list of a dict (with an array as value) and an array</span>
</span></span><span class="line"><span class="cl"><span class="n">tree</span> <span class="o">=</span> <span class="p">([[{</span><span class="s2">&#34;foo&#34;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">4.0</span><span class="p">])}],</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">9.0</span><span class="p">])],)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># sqrt of each leaf array</span>
</span></span><span class="line"><span class="cl"><span class="n">sqrt_tree</span> <span class="o">=</span> <span class="n">pt</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">,</span> <span class="n">tree</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&#34;</span><span class="si">{</span><span class="n">sqrt_tree</span><span class="si">=}</span><span class="s2">&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="c1"># &gt;&gt; sqrt_tree=([[{&#39;foo&#39;: array([2.])}], array([3.])],)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># reductions</span>
</span></span><span class="line"><span class="cl"><span class="n">all_positive</span> <span class="o">=</span> <span class="nb">all</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">all</span><span class="p">(</span><span class="n">x</span> <span class="o">&gt;</span> <span class="mf">0.0</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">pt</span><span class="o">.</span><span class="n">tree_iter</span><span class="p">(</span><span class="n">tree</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&#34;</span><span class="si">{</span><span class="n">all_positive</span><span class="si">=}</span><span class="s2">&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="c1"># &gt;&gt; all_positive=True</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># add the leaves pairwise (element-wise), then sum all elements</span>
</span></span><span class="line"><span class="cl"><span class="n">summed</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">pt</span><span class="o">.</span><span class="n">tree_reduce</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">add</span><span class="p">,</span> <span class="n">tree</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&#34;</span><span class="si">{</span><span class="n">summed</span><span class="si">=}</span><span class="s2">&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="c1"># &gt;&gt; summed=np.float64(13.0)</span></span></span></code></pre>
</div>
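<p>The trick is that these operations reduce to flattening the tree, operating on the flat leaves, and (where needed) unflattening the result. For example, <code>tree_map</code> can be written in three steps:</p>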


<div class="highlight">
  <pre class="chroma"><code><span class="line"><span class="cl"><span class="c1"># the function to apply to each leaf (here the square root, as above)</span>
</span></span><span class="line"><span class="cl"><span class="n">fun</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># step 1: flatten the tree into a list of leaves and a tree definition (its blueprint)</span>
</span></span><span class="line"><span class="cl"><span class="n">leaves</span><span class="p">,</span> <span class="n">treedef</span> <span class="o">=</span> <span class="n">pt</span><span class="o">.</span><span class="n">tree_flatten</span><span class="p">(</span><span class="n">tree</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># step 2: apply `fun` to each leaf</span>
</span></span><span class="line"><span class="cl"><span class="n">new_leaves</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="n">fun</span><span class="p">,</span> <span class="n">leaves</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># step 3: rebuild a tree with the original structure from the new leaves</span>
</span></span><span class="line"><span class="cl"><span class="n">result_tree</span> <span class="o">=</span> <span class="n">pt</span><span class="o">.</span><span class="n">tree_unflatten</span><span class="p">(</span><span class="n">treedef</span><span class="p">,</span> <span class="n">new_leaves</span><span class="p">)</span></span></span></code></pre>
</div>
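<p>To make the idea of a &ldquo;blueprint&rdquo; concrete, the following toy sketch implements flattening and unflattening in pure Python and then runs the three steps above with it. It is an illustration only, not how optree works internally: the names <code>toy_flatten</code> and <code>toy_unflatten</code> are hypothetical, only plain <code>dict</code>, <code>list</code> and <code>tuple</code> are treated as containers, and the blueprint is just a nested tuple of (container type, dict keys, child blueprints).</p>

```python
def toy_flatten(tree):
    """Split `tree` into a flat list of leaves and a nested blueprint.

    Only plain dict, list and tuple count as containers in this toy;
    any other object (including subclasses such as NamedTuple) is a leaf.
    """
    if type(tree) is dict:
        keys = list(tree)  # keep the dict's insertion order
        children = [tree[k] for k in keys]
    elif type(tree) in (list, tuple):
        keys = None
        children = list(tree)
    else:
        return [tree], None  # a single leaf; None marks its position

    leaves, child_defs = [], []
    for child in children:
        child_leaves, child_def = toy_flatten(child)
        leaves.extend(child_leaves)
        child_defs.append(child_def)
    return leaves, (type(tree), keys, child_defs)


def toy_unflatten(treedef, leaves):
    """Rebuild a tree from a blueprint and leaves given in flattening order."""
    leaves_iter = iter(leaves)

    def build(node):
        if node is None:  # leaf position: take the next leaf
            return next(leaves_iter)
        container_type, keys, child_defs = node
        children = [build(child_def) for child_def in child_defs]
        if container_type is dict:
            return dict(zip(keys, children))
        return container_type(children)  # list(...) or tuple(...)

    return build(treedef)


# the three steps of tree_map, using the toy helpers and plain floats
tree = ([[{"foo": 4.0}], 9.0],)
leaves, treedef = toy_flatten(tree)  # step 1
new_leaves = [x ** 0.5 for x in leaves]  # step 2
result_tree = toy_unflatten(treedef, new_leaves)  # step 3
print(leaves)  # [4.0, 9.0]
print(result_tree)  # ([[{'foo': 2.0}], 3.0],)
```

<p>A full PyTree library such as optree additionally supports named tuples, user-registered container types and more, which is why we rely on it for the rest of this post.</p>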
<h3 id="pytree-origins">PyTree Origins<a class="headerlink" href="#pytree-origins" title="Link to this heading">#</a></h3>
<p>Originally, the concept of PyTrees was developed by the <a href="https://docs.jax.dev/en/latest/">JAX</a> project to make nested collections of JAX arrays work transparently at the &ldquo;JIT-boundary&rdquo; (the JAX JIT toolchain does not know about Python containers, only about JAX Arrays).
However, PyTrees were quickly adopted by AI researchers for broader use cases: semantically grouping the weights and biases of a neural network&rsquo;s layers in a list of named tuples (or dictionaries) is a common pattern in the JAX AI world, as shown in the following Python snippet (random weights, with layer sizes chosen for illustration):</p>


<div class="highlight">
  <pre class="chroma"><code><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">NamedTuple</span>
</span></span><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">jax</span>
</span></span><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Layer</span><span class="p">(</span><span class="n">NamedTuple</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">W</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span>
</span></span><span class="line"><span class="cl">    <span class="n">b</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># two layers: 3 inputs, 4 hidden units, 2 outputs (random weights, zero biases)</span>
</span></span><span class="line"><span class="cl"><span class="n">key1</span><span class="p">,</span> <span class="n">key2</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="n">layers</span> <span class="o">=</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="n">Layer</span><span class="p">(</span><span class="n">W</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key1</span><span class="p">,</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">)),</span> <span class="n">b</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">4</span><span class="p">)),</span>  <span class="c1"># first layer</span>
</span></span><span class="line"><span class="cl">    <span class="n">Layer</span><span class="p">(</span><span class="n">W</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key2</span><span class="p">,</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">)),</span> <span class="n">b</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">2</span><span class="p">)),</span>  <span class="c1"># second layer</span>
</span></span><span class="line"><span class="cl"><span class="p">]</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nd">@jax.jit</span>
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">neural_network</span><span class="p">(</span><span class="n">layers</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Layer</span><span class="p">],</span> <span class="n">x</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">layers</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">W</span> <span class="o">@</span> <span class="n">x</span> <span class="o">+</span> <span class="n">layer</span><span class="o">.</span><span class="n">b</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">x</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">prediction</span> <span class="o">=</span> <span class="n">neural_network</span><span class="p">(</span><span class="n">layers</span><span class="o">=</span><span class="n">layers</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">3</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="nb">print</span><span class="p">(</span><span class="n">prediction</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="c1"># &gt;&gt; (2,)</span></span></span></code></pre>
</div>
<p>Here, <code>layers</code> is a PyTree (a <code>list</code> of <code>Layer</code> named tuples), and the JIT-compiled <code>neural_network</code> function <em>just works</em> with this data structure as input.
Behind the scenes, the <code>jax.jit</code> decorator automatically flattens <code>layers</code> into a flat sequence of arrays, which the JAX JIT toolchain understands, unlike a Python <code>list</code> of <code>NamedTuple</code>s.</p>
<h3 id="pytrees-in-scientific-python">PyTrees in Scientific Python<a class="headerlink" href="#pytrees-in-scientific-python" title="Link to this heading">#</a></h3>
<p>Wouldn&rsquo;t it be nice to make workflows in the scientific Python ecosystem <em>just work</em> with any PyTree?</p>
<p>Giving semantic meaning to numeric data through PyTrees can be useful for applications outside of AI as well.
Consider the following minimization of the <a href="https://en.wikipedia.org/wiki/Rosenbrock_function">Rosenbrock</a> function:</p>


<div class="highlight">
  <pre class="chroma"><code><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">scipy.optimize</span> <span class="kn">import</span> <span class="n">minimize</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">rosenbrock</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">    Rosenbrock function. Minimum: f(1, 1) = 0.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">    https://en.wikipedia.org/wiki/Rosenbrock_function
</span></span></span><span class="line"><span class="cl"><span class="s2">    &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">params</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">x</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">100</span> <span class="o">*</span> <span class="p">(</span><span class="n">y</span> <span class="o">-</span> <span class="n">x</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">x0</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">1.2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">res</span> <span class="o">=</span> <span class="n">minimize</span><span class="p">(</span><span class="n">rosenbrock</span><span class="p">,</span> <span class="n">x0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="nb">print</span><span class="p">(</span><span class="n">res</span><span class="o">.</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="c1"># &gt;&gt; [0.99999569 0.99999137]</span></span></span></code></pre>
</div>
<p>Now, let&rsquo;s consider a minimization that uses a more complex type for the parameters — a NamedTuple that describes our fit parameters:</p>


<div class="highlight">
  <pre class="chroma"><code><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">optree</span> <span class="k">as</span> <span class="nn">pt</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">NamedTuple</span><span class="p">,</span> <span class="n">Callable</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">scipy.optimize</span> <span class="kn">import</span> <span class="n">minimize</span> <span class="k">as</span> <span class="n">sp_minimize</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Params</span><span class="p">(</span><span class="n">NamedTuple</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">x</span><span class="p">:</span> <span class="nb">float</span>
</span></span><span class="line"><span class="cl">    <span class="n">y</span><span class="p">:</span> <span class="nb">float</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">rosenbrock</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="n">Params</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">    Rosenbrock function. Minimum: f(1, 1) = 0.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">    https://en.wikipedia.org/wiki/Rosenbrock_function
</span></span></span><span class="line"><span class="cl"><span class="s2">    &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">params</span><span class="o">.</span><span class="n">x</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">100</span> <span class="o">*</span> <span class="p">(</span><span class="n">params</span><span class="o">.</span><span class="n">y</span> <span class="o">-</span> <span class="n">params</span><span class="o">.</span><span class="n">x</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">minimize</span><span class="p">(</span><span class="n">fun</span><span class="p">:</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">params</span><span class="p">:</span> <span class="n">Params</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Params</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="c1"># flatten and store PyTree definition</span>
</span></span><span class="line"><span class="cl">    <span class="n">flat_params</span><span class="p">,</span> <span class="n">treedef</span> <span class="o">=</span> <span class="n">pt</span><span class="o">.</span><span class="n">tree_flatten</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c1"># wrap fun to work with flat_params</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">wrapped_fun</span><span class="p">(</span><span class="n">flat_params</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">params</span> <span class="o">=</span> <span class="n">pt</span><span class="o">.</span><span class="n">tree_unflatten</span><span class="p">(</span><span class="n">treedef</span><span class="p">,</span> <span class="n">flat_params</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">fun</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c1"># actual minimization</span>
</span></span><span class="line"><span class="cl">    <span class="n">res</span> <span class="o">=</span> <span class="n">sp_minimize</span><span class="p">(</span><span class="n">wrapped_fun</span><span class="p">,</span> <span class="n">flat_params</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c1"># re-wrap the bestfit values into Params with stored PyTree definition</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">pt</span><span class="o">.</span><span class="n">tree_unflatten</span><span class="p">(</span><span class="n">treedef</span><span class="p">,</span> <span class="n">res</span><span class="o">.</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># scipy minimize that works with any PyTree</span>
</span></span><span class="line"><span class="cl"><span class="n">x0</span> <span class="o">=</span> <span class="n">Params</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="mf">1.2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">bestfit_params</span> <span class="o">=</span> <span class="n">minimize</span><span class="p">(</span><span class="n">rosenbrock</span><span class="p">,</span> <span class="n">x0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="nb">print</span><span class="p">(</span><span class="n">bestfit_params</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="c1"># &gt;&gt; Params(x=np.float64(0.999995688776513), y=np.float64(0.9999913673387226))</span></span></span></code></pre>
</div>
<p>This new <code>minimize</code> function works with <em>any</em> PyTree whose leaves are scalars!</p>
<p>Let&rsquo;s now consider a modified and more complex version of the Rosenbrock function that relies on two sets of <code>Params</code> as input — a common pattern for hierarchical models (e.g. a superposition of various probability density functions):</p>


<div class="highlight">
  <pre class="chroma"><code><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">rosenbrock_modified</span><span class="p">(</span><span class="n">two_params</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">[</span><span class="n">Params</span><span class="p">,</span> <span class="n">Params</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">    Modified Rosenbrock where the x and y parameters are determined by
</span></span></span><span class="line"><span class="cl"><span class="s2">    non-linear transformations of two versions of each, i.e.:
</span></span></span><span class="line"><span class="cl"><span class="s2">      x = arcsin(min(x1, x2) / max(x1, x2))
</span></span></span><span class="line"><span class="cl"><span class="s2">      y = sigmoid(y1 / y2)
</span></span></span><span class="line"><span class="cl"><span class="s2">    &#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span> <span class="o">=</span> <span class="n">two_params</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="c1"># calculate `x` and `y` from two sources:</span>
</span></span><span class="line"><span class="cl">    <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arcsin</span><span class="p">(</span><span class="nb">min</span><span class="p">(</span><span class="n">p1</span><span class="o">.</span><span class="n">x</span><span class="p">,</span> <span class="n">p2</span><span class="o">.</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">p1</span><span class="o">.</span><span class="n">x</span><span class="p">,</span> <span class="n">p2</span><span class="o">.</span><span class="n">x</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">    <span class="n">y</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="n">p1</span><span class="o">.</span><span class="n">y</span> <span class="o">/</span> <span class="n">p2</span><span class="o">.</span><span class="n">y</span><span class="p">)))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">x</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">100</span> <span class="o">*</span> <span class="p">(</span><span class="n">y</span> <span class="o">-</span> <span class="n">x</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">x0</span> <span class="o">=</span> <span class="p">(</span><span class="n">Params</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="mf">1.2</span><span class="p">),</span> <span class="n">Params</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="mf">1.3</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="n">bestfit_params</span> <span class="o">=</span> <span class="n">minimize</span><span class="p">(</span><span class="n">rosenbrock_modified</span><span class="p">,</span> <span class="n">x0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="nb">print</span><span class="p">(</span><span class="n">bestfit_params</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="c1"># &gt;&gt; (</span>
</span></span><span class="line"><span class="cl"><span class="c1">#     Params(x=np.float64(4.686181110201706), y=np.float64(0.05129869722505759)),</span>
</span></span><span class="line"><span class="cl"><span class="c1">#     Params(x=np.float64(3.9432263101976073), y=np.float64(0.005146110126174016)),</span>
</span></span><span class="line"><span class="cl"><span class="c1"># )</span></span></span></code></pre>
</div>
<p>The new <code>minimize</code> still works, because a <code>tuple</code> of <code>Params</code> is just <em>another</em> PyTree!</p>
<h3 id="final-thought">Final Thought<a class="headerlink" href="#final-thought" title="Link to this heading">#</a></h3>
<p>Working with nested data structures doesn’t have to be messy.
PyTrees let you focus on the data and the transformations you want to apply, in a generic manner.
Whether you&rsquo;re building neural networks, optimizing scientific models, or just dealing with complex nested Python containers, PyTrees can make your code cleaner, more flexible, and just nicer to work with.</p>
]]></content>
            
                 
                    
                 
                    
                         
                        
                            
                             
                                <category scheme="taxonomy:Tags" term="pytrees" label="PyTrees" />
                             
                                <category scheme="taxonomy:Tags" term="functional-programming" label="Functional Programming" />
                             
                                <category scheme="taxonomy:Tags" term="tree-like-data-manipulation" label="Tree-like data manipulation" />
                            
                        
                    
                
            
        </entry>
    
</feed>
