```
>> f = x^3 + 2x
x^3 + 2 x
>> f'(x)
3 * x^2 + 2
>> f'(x)'(x)
6 * x
```
Note: the simplifier isn't perfect, so some results may be messier than they should be. But the core works.
---

### Representing expressions

Everything in this program is an `expr`: a number, a variable, `sin(x)`, `x^2+1`. All of them are the same struct:
```c
typedef struct expr {
    exprtype type;
    int mindepth, maxdepth, nodes;
    int ref;                       /* reference count, bumped by retain() */
    union {
        char symbol;               /* leaf: a variable */
        double num;                /* leaf: a constant */
        struct {                   /* binary node: two children */
            struct expr *left;
            struct expr *right;
        };
        struct expr *unary;        /* unary node: one child */
    };
} expr;
```

The union is the interesting part. Depending on the type, an expression is either:

- a leaf: just a number or a symbol (`x`, `3.14`)
- a unary node: exactly one child (`sin(x)`, `-x`, `log(x)`)
- a binary node: exactly two children (`x+1`, `x^2`)
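As a rough sketch of how nodes like this could be created and shared, here is a minimal, self-contained version. It assumes pointer children, C11 anonymous structs and unions, and that `ref` is a reference count (suggested by the `retain` calls later in the post). The constructors `num`, `sym`, `mkun`, `mkbin` and the `release` function are hypothetical, the depth and size fields are omitted, and the enum lists only the node kinds named in this post.

```c
#include <assert.h>
#include <stdlib.h>

/* Only the node kinds named in this post; the real enum likely has more. */
typedef enum {
    EXPR_NUM, EXPR_SYM,             /* leaves */
    EXPR_SIN, EXPR_LOG,             /* unary  */
    EXPR_ADD, EXPR_MUL, EXPR_EXP    /* binary */
} exprtype;

typedef struct expr {
    exprtype type;
    int ref;                        /* number of owners of this node */
    union {
        char symbol;                /* leaf: a variable such as x */
        double num;                 /* leaf: a constant such as 3.14 */
        struct { struct expr *left, *right; };  /* binary: two children */
        struct expr *unary;         /* unary: one child */
    };
} expr;

static expr *node(exprtype type) {
    expr *e = calloc(1, sizeof *e);
    assert(e != NULL && "out of memory");
    e->type = type;
    e->ref = 1;                     /* the caller owns the new node */
    return e;
}

expr *num(double v) { expr *e = node(EXPR_NUM); e->num = v; return e; }
expr *sym(char c)   { expr *e = node(EXPR_SYM); e->symbol = c; return e; }

/* Compound constructors take over the caller's references to the children. */
expr *mkun(exprtype t, expr *arg) {
    expr *e = node(t);
    e->unary = arg;
    return e;
}

expr *mkbin(exprtype t, expr *l, expr *r) {
    expr *e = node(t);
    e->left = l;
    e->right = r;
    return e;
}

/* Share a subtree: register one more owner. */
expr *retain(expr *e) { e->ref++; return e; }

/* Drop one owner; the last owner frees the node and releases its children. */
void release(expr *e) {
    if (--e->ref > 0) return;
    switch (e->type) {
    case EXPR_SIN: case EXPR_LOG:
        release(e->unary);
        break;
    case EXPR_ADD: case EXPR_MUL: case EXPR_EXP:
        release(e->left);
        release(e->right);
        break;
    default:                        /* leaves own no children */
        break;
    }
    free(e);
}
```

Usage: `mkbin(EXPR_ADD, mkbin(EXPR_EXP, sym('x'), num(2)), num(1))` builds the `x^2+1` tree from the next paragraph, and a single `release` on the root frees all five nodes. A subtree that was `retain`ed elsewhere survives until its other owner releases it too.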
So the expression `x^2+1` becomes a tree like this:

```
      ADD
     /   \
   EXP    1
   / \
  x   2
```

Every node is the same struct. You navigate it recursively. This pattern shows up everywhere in the codebase.
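Because every node has the same shape, most functions over expressions share one skeleton: handle the leaves, then recurse into the children. Here is a hedged sketch with hypothetical helpers (`count_nodes`, `max_depth`, `to_string`), restricted to leaves and binary nodes. The first two compute the kind of quantity the struct's `nodes` and `maxdepth` fields appear to cache, though that reading of the fields is an assumption.

```c
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef enum { EXPR_NUM, EXPR_SYM, EXPR_ADD, EXPR_MUL, EXPR_EXP } exprtype;

typedef struct expr {
    exprtype type;
    union {
        char symbol;
        double num;
        struct { struct expr *left, *right; };
    };
} expr;

static expr *node(exprtype t) {
    expr *e = calloc(1, sizeof *e);
    assert(e != NULL);
    e->type = t;
    return e;
}
expr *num(double v) { expr *e = node(EXPR_NUM); e->num = v; return e; }
expr *sym(char c) { expr *e = node(EXPR_SYM); e->symbol = c; return e; }
expr *bin(exprtype t, expr *l, expr *r) {
    expr *e = node(t);
    e->left = l;
    e->right = r;
    return e;
}

static int is_leaf(const expr *e) {
    return e->type == EXPR_NUM || e->type == EXPR_SYM;
}

/* Total number of nodes: this one plus both subtrees. */
int count_nodes(const expr *e) {
    if (is_leaf(e)) return 1;
    return 1 + count_nodes(e->left) + count_nodes(e->right);
}

/* Longest root-to-leaf path, counting a lone leaf as depth 1. */
int max_depth(const expr *e) {
    if (is_leaf(e)) return 1;
    int l = max_depth(e->left), r = max_depth(e->right);
    return 1 + (l > r ? l : r);
}

/* Append e to buf as fully parenthesised infix; buf must be large enough. */
static void append(const expr *e, char *buf) {
    char *end = buf + strlen(buf);
    switch (e->type) {
    case EXPR_NUM: sprintf(end, "%g", e->num); return;
    case EXPR_SYM: sprintf(end, "%c", e->symbol); return;
    default:
        strcat(buf, "(");
        append(e->left, buf);
        strcat(buf, e->type == EXPR_ADD ? "+" : e->type == EXPR_MUL ? "*" : "^");
        append(e->right, buf);
        strcat(buf, ")");
    }
}

void to_string(const expr *e, char *buf) {
    buf[0] = '\0';
    append(e, buf);
}
```

For the `x^2+1` tree above, `count_nodes` returns 5, `max_depth` returns 3, and `to_string` produces `((x^2)+1)`.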
---

### Differentiation

Symbolic differentiation is just pattern matching on the expression tree. Each node type has a rule, and you apply it recursively. The main function is a switch:

```c
expr *derive(expr *f, expr *sym) {
    switch (f->type) {
    case EXPR_ADD: return deriveAdd(f, sym);
    case EXPR_MUL: return deriveMul(f, sym);
    case EXPR_EXP: return deriveExp(f, sym);
    case EXPR_SIN: return deriveSin(f, sym);
    case EXPR_LOG: return deriveLog(f, sym);
    case EXPR_SYM: return eq(f, sym) ? one : zero;  // dx/dx = 1, dy/dx = 0
    case EXPR_NUM: return zero;                      // constants vanish
    // ...
    }
}
```

The base cases are trivial: the derivative of a constant is zero, and the derivative of a symbol is one if it is the variable we differentiate with respect to and zero otherwise.
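A minimal version of that dispatch, trimmed to sums, products and the two base cases, could look like this. It assumes pointer-typed nodes and shared `zero`/`one` constants as in the snippet above; `eq`, `add`, `mul` and the numeric `eval` used for checking are hypothetical helpers, and the sketch shares subtrees without any reference counting or freeing.

```c
#include <assert.h>
#include <stdlib.h>

typedef enum { EXPR_NUM, EXPR_SYM, EXPR_ADD, EXPR_MUL } exprtype;

typedef struct expr {
    exprtype type;
    union {
        char symbol;
        double num;
        struct { struct expr *left, *right; };
    };
} expr;

static expr *node(exprtype t) {
    expr *e = calloc(1, sizeof *e);
    assert(e != NULL);
    e->type = t;
    return e;
}
expr *num(double v) { expr *e = node(EXPR_NUM); e->num = v; return e; }
expr *sym(char c) { expr *e = node(EXPR_SYM); e->symbol = c; return e; }
expr *add(expr *l, expr *r) { expr *e = node(EXPR_ADD); e->left = l; e->right = r; return e; }
expr *mul(expr *l, expr *r) { expr *e = node(EXPR_MUL); e->left = l; e->right = r; return e; }

/* Shared constants returned by the base cases; never freed. */
static expr ZERO = { .type = EXPR_NUM, .num = 0 };
static expr ONE = { .type = EXPR_NUM, .num = 1 };
expr *const zero = &ZERO, *const one = &ONE;

/* Two symbol nodes are equal if they name the same variable. */
static int eq(const expr *a, const expr *b) {
    return a->type == EXPR_SYM && b->type == EXPR_SYM && a->symbol == b->symbol;
}

expr *derive(expr *f, expr *sym) {
    switch (f->type) {
    case EXPR_ADD:              /* (f + g)' = f' + g' */
        return add(derive(f->left, sym), derive(f->right, sym));
    case EXPR_MUL:              /* (f * g)' = f' * g + f * g' */
        return add(mul(derive(f->left, sym), f->right),
                   mul(f->left, derive(f->right, sym)));
    case EXPR_SYM:
        return eq(f, sym) ? one : zero;
    case EXPR_NUM:
        return zero;
    }
    return zero;                /* unreachable for well-formed trees */
}

/* Evaluate e with every symbol bound to x (enough for one-variable checks). */
double eval(const expr *e, double x) {
    switch (e->type) {
    case EXPR_NUM: return e->num;
    case EXPR_SYM: return x;
    case EXPR_ADD: return eval(e->left, x) + eval(e->right, x);
    case EXPR_MUL: return eval(e->left, x) * eval(e->right, x);
    }
    return 0;
}
```

For `f = x*x + 3*x`, evaluating `derive(f, x)` at `x = 2` gives 7, matching `2x + 3`, while differentiating the same `f` with respect to `y` yields a tree that evaluates to 0.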
The interesting cases are the compound rules. Here's the product rule:

```c
expr *deriveMul(expr *f, expr *sym) {
    expr *left = mul(derive(f->left, sym), retain(f->right));   // f' * g
    expr *right = mul(retain(f->left), derive(f->right, sym));  // f * g'
    return add(left, right);
}
```

That's literally `f' * g + f * g'`. The `retain` calls are there because the result shares the original subtrees instead of copying them, so the derivative has to hold its own reference to each one.

The chain rule falls out naturally too. When you differentiate `sin(x^2)`, `deriveSin` calls `derive` on its argument, which recurses into the `x^2` subtree. You don't write chain rule handling separately; it's just recursion.
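To watch the chain rule fall out of plain recursion, here is a hedged sketch that adds `sin` and `cos` nodes to a small differentiator. `EXPR_COS`, `esin`, `ecos` and the `to_string` printer are hypothetical, `x*x` stands in for `x^2` so only the product rule is needed, and nothing is simplified or freed, so the output is deliberately raw.

```c
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef enum { EXPR_NUM, EXPR_SYM, EXPR_ADD, EXPR_MUL, EXPR_SIN, EXPR_COS } exprtype;

typedef struct expr {
    exprtype type;
    union {
        char symbol;
        double num;
        struct { struct expr *left, *right; };
        struct expr *unary;
    };
} expr;

static expr *node(exprtype t) {
    expr *e = calloc(1, sizeof *e);
    assert(e != NULL);
    e->type = t;
    return e;
}
expr *num(double v) { expr *e = node(EXPR_NUM); e->num = v; return e; }
expr *sym(char c) { expr *e = node(EXPR_SYM); e->symbol = c; return e; }
static expr *bin(exprtype t, expr *l, expr *r) {
    expr *e = node(t);
    e->left = l;
    e->right = r;
    return e;
}
static expr *un(exprtype t, expr *u) { expr *e = node(t); e->unary = u; return e; }
expr *add(expr *l, expr *r) { return bin(EXPR_ADD, l, r); }
expr *mul(expr *l, expr *r) { return bin(EXPR_MUL, l, r); }
expr *esin(expr *u) { return un(EXPR_SIN, u); }
expr *ecos(expr *u) { return un(EXPR_COS, u); }

/* The unary rules multiply by derive(inner): that call is the whole chain rule. */
expr *derive(expr *f, expr *sym) {
    switch (f->type) {
    case EXPR_NUM: return num(0);
    case EXPR_SYM: return num(f->symbol == sym->symbol ? 1 : 0);
    case EXPR_ADD: return add(derive(f->left, sym), derive(f->right, sym));
    case EXPR_MUL:                      /* f' * g + f * g' */
        return add(mul(derive(f->left, sym), f->right),
                   mul(f->left, derive(f->right, sym)));
    case EXPR_SIN:                      /* sin(u)' = cos(u) * u' */
        return mul(ecos(f->unary), derive(f->unary, sym));
    case EXPR_COS:                      /* cos(u)' = -1 * sin(u) * u' */
        return mul(mul(num(-1), esin(f->unary)), derive(f->unary, sym));
    }
    return num(0);
}

/* Append e as fully parenthesised infix; buf must be large enough. */
static void append(const expr *e, char *buf) {
    char *end = buf + strlen(buf);
    switch (e->type) {
    case EXPR_NUM: sprintf(end, "%g", e->num); break;
    case EXPR_SYM: sprintf(end, "%c", e->symbol); break;
    case EXPR_SIN:
    case EXPR_COS:
        strcat(buf, e->type == EXPR_SIN ? "sin(" : "cos(");
        append(e->unary, buf);
        strcat(buf, ")");
        break;
    default:
        strcat(buf, "(");
        append(e->left, buf);
        strcat(buf, e->type == EXPR_ADD ? "+" : "*");
        append(e->right, buf);
        strcat(buf, ")");
    }
}

void to_string(const expr *e, char *buf) { buf[0] = '\0'; append(e, buf); }
```

Differentiating `sin(x*x)` with this sketch yields `(cos((x*x))*((1*x)+(x*1)))`, which is `cos(x^2) * 2x` once cleaned up. The `cos(...)` factor comes from the `sin` rule and the second factor is the recursive call on the argument, with no dedicated chain-rule code anywhere.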
---

### Simplification

Differentiation produces correct but messy results. `derive(x^2)` gives you `x^2 * (0 * log(x) + 2 * 1 / x)` before any cleanup.

Simplification is a second recursive pass over the tree that applies algebraic identities. Some cases are straightforward:

```
x * 0 = 0
1 * x = x
x / x = 1
x^1 = x
x^0 = 1
...
```

The simplifier works bottom-up: it simplifies the children first, then applies the rules to the result.

This is also where the "not perfect" disclaimer comes in. Simplification is essentially rewriting, and you can always find expressions that don't reduce as far as they could. Getting it fully right would mean implementing a proper term rewriting system, which is a rabbit hole I chose not to go down.
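A bottom-up pass covering a few of the identities above, plus folding of numeric constants, might look like the sketch below. It assumes pointer-based nodes; `simplify`, `is_num` and `bin` are hypothetical, `x / x = 1` is left out because it needs a division node and a structural equality check, and discarded nodes are simply leaked rather than released.

```c
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef enum { EXPR_NUM, EXPR_SYM, EXPR_ADD, EXPR_MUL, EXPR_EXP } exprtype;

typedef struct expr {
    exprtype type;
    union {
        char symbol;
        double num;
        struct { struct expr *left, *right; };
    };
} expr;

static expr *node(exprtype t) {
    expr *e = calloc(1, sizeof *e);
    assert(e != NULL);
    e->type = t;
    return e;
}
expr *num(double v) { expr *e = node(EXPR_NUM); e->num = v; return e; }
expr *sym(char c) { expr *e = node(EXPR_SYM); e->symbol = c; return e; }
expr *bin(exprtype t, expr *l, expr *r) {
    expr *e = node(t);
    e->left = l;
    e->right = r;
    return e;
}

static int is_num(const expr *e, double v) {
    return e->type == EXPR_NUM && e->num == v;
}

/* Bottom-up: simplify both children first, then try the rules on the result. */
expr *simplify(expr *e) {
    if (e->type == EXPR_NUM || e->type == EXPR_SYM) return e;
    expr *l = simplify(e->left), *r = simplify(e->right);

    if (l->type == EXPR_NUM && r->type == EXPR_NUM) {    /* constant folding */
        if (e->type == EXPR_ADD) return num(l->num + r->num);
        if (e->type == EXPR_MUL) return num(l->num * r->num);
    }
    switch (e->type) {
    case EXPR_ADD:
        if (is_num(l, 0)) return r;                       /* 0 + x = x */
        if (is_num(r, 0)) return l;                       /* x + 0 = x */
        break;
    case EXPR_MUL:
        if (is_num(l, 0) || is_num(r, 0)) return num(0);  /* x * 0 = 0 */
        if (is_num(l, 1)) return r;                       /* 1 * x = x */
        if (is_num(r, 1)) return l;                       /* x * 1 = x */
        break;
    case EXPR_EXP:
        if (is_num(r, 0)) return num(1);                  /* x^0 = 1 */
        if (is_num(r, 1)) return l;                       /* x^1 = x */
        break;
    default:
        break;
    }
    return bin(e->type, l, r);    /* no rule fired: rebuild with simplified children */
}

/* Append e as fully parenthesised infix; buf must be large enough. */
static void append(const expr *e, char *buf) {
    char *end = buf + strlen(buf);
    if (e->type == EXPR_NUM) { sprintf(end, "%g", e->num); return; }
    if (e->type == EXPR_SYM) { sprintf(end, "%c", e->symbol); return; }
    strcat(buf, "(");
    append(e->left, buf);
    strcat(buf, e->type == EXPR_ADD ? "+" : e->type == EXPR_MUL ? "*" : "^");
    append(e->right, buf);
    strcat(buf, ")");
}

void to_string(const expr *e, char *buf) { buf[0] = '\0'; append(e, buf); }
```

For example, `(x^1 * 1) + (0 * x)` collapses to `x`, and `2 * 3 + x` becomes `(6+x)`. Because children are simplified first, a rule at the parent sees already-reduced operands, which is what lets `x^(1 + -1)` fold to `1` in one pass.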
---

### The parser

The REPL reads a string and turns it into an expression tree. This is done with a hand-written recursive descent parser: no lexer generator, no parser library.
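Recursive descent gives each precedence level its own function, and each one calls the next tighter level. The sketch below follows the `unary`, `postfix`, `factor` and `term` rules of the grammar quoted next, but only for numbers, one-letter symbols, parentheses, unary minus, `^`, `*`, `/`, `+` and `-`; the derivative and substitution postfixes, assignment and error reporting are left out. All names (`parse`, `level`, `EXPR_NEG`, `to_string` and the rest) are hypothetical, and using each operator's character as its enum value is just a shortcut for this sketch.

```c
#include <assert.h>
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical node kinds: operator kinds use their own character as the
   enum value, a shortcut that keeps parsing and printing short. */
typedef enum {
    EXPR_NUM = 'n', EXPR_SYM = 's', EXPR_NEG = '~',
    EXPR_ADD = '+', EXPR_SUB = '-', EXPR_MUL = '*', EXPR_DIV = '/', EXPR_EXP = '^'
} exprtype;

typedef struct expr {
    exprtype type;
    union {
        char symbol;
        double num;
        struct { struct expr *left, *right; };
        struct expr *unary;
    };
} expr;

static expr *node(exprtype t) {
    expr *e = calloc(1, sizeof *e);
    assert(e != NULL);
    e->type = t;
    return e;
}

static const char *src;   /* parser cursor into the input string */

/* Skip blanks, then return the next character without consuming it. */
static char peek(void) {
    while (isspace((unsigned char)*src)) src++;
    return *src;
}

static expr *parse_term(void);

/* unary = num | sym | "-" unary | "(" expr ")" */
static expr *parse_unary(void) {
    char c = peek();
    if (c == '-') {
        src++;
        expr *e = node(EXPR_NEG);
        e->unary = parse_unary();
        return e;
    }
    if (c == '(') {
        src++;
        expr *e = parse_term();
        if (peek() == ')') src++;   /* a real parser would report a missing ')' */
        return e;
    }
    expr *e = node(isalpha((unsigned char)c) ? EXPR_SYM : EXPR_NUM);
    if (e->type == EXPR_SYM) {
        e->symbol = *src++;
    } else {
        char *end;
        e->num = strtod(src, &end); /* no digits: yields 0, consumes nothing */
        src = end;
    }
    return e;
}

/* One left-associative level: next ((a | b) next)* */
static expr *level(expr *(*next)(void), char a, char b) {
    expr *e = next();
    while (peek() == a || peek() == b) {
        expr *op = node((exprtype)*src++);
        op->left = e;
        op->right = next();
        e = op;
    }
    return e;
}

/* postfix = unary ("^" unary)*   (the '(sym) and (sym=num) postfixes are omitted) */
static expr *parse_postfix(void) { return level(parse_unary, '^', '^'); }
/* factor = postfix (("*" | "/") postfix)* */
static expr *parse_factor(void) { return level(parse_postfix, '*', '/'); }
/* term = factor (("+" | "-") factor)*   (the assignment rule is omitted) */
static expr *parse_term(void) { return level(parse_factor, '+', '-'); }

expr *parse(const char *s) { src = s; return parse_term(); }

/* Append e as fully parenthesised infix; buf must be large enough. */
static void append(const expr *e, char *buf) {
    char *end = buf + strlen(buf);
    if (e->type == EXPR_NUM) { sprintf(end, "%g", e->num); return; }
    if (e->type == EXPR_SYM) { sprintf(end, "%c", e->symbol); return; }
    strcat(buf, "(");
    if (e->type == EXPR_NEG) {
        strcat(buf, "-");
        append(e->unary, buf);
    } else {
        append(e->left, buf);
        sprintf(buf + strlen(buf), "%c", (char)e->type);
        append(e->right, buf);
    }
    strcat(buf, ")");
}

void to_string(const expr *e, char *buf) { buf[0] = '\0'; append(e, buf); }
```

With this sketch, `parse("1 + 2 * x")` prints as `(1+(2*x))` and `parse("x - y - 1")` as `((x-y)-1)`: the loop in `level` makes every precedence level left-associative. Read literally, the grammar also groups `-x^2` as `(-x)^2`, because `"-" unary` is handled inside `unary`, one level tighter than `^`.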
The grammar looks like this (it's also written out as a comment in the source):

```
unary   = num | sym | "-" unary | "(" expr ")"
postfix = unary ("^" unary | "'(" sym ")" | "(" sym "=" num ")")*
factor  = postfix (("*" | "/") postfix)*
term    = factor (("+" | "-") factor)*
expr    = term | sym "=" expr
```

---

If you want to know more about the implementation, you can check the full blog post here: [derive.c](https://marcomit.it/derive.c)