Algorithm Implementation/Viterbi algorithm
The following implementations of the Viterbi algorithm were removed from an earlier copy of the Wikipedia page because they were too long and unencyclopaedic — but we hope you'll find them useful here!
Java implementation
import java.util.HashMap;
import java.util.Map;

/**
 * Forward and Viterbi algorithms for a discrete hidden Markov model, demonstrated on the
 * two-state Healthy/Fever example.
 */
public class Viterbi {
  static final String HEALTHY = "Healthy";
  static final String FEVER = "Fever";
  static final String DIZZY = "dizzy";
  static final String COLD = "cold";
  static final String NORMAL = "normal";

  /** Runs the example model and prints the total probability, the Viterbi path and its probability. */
  public static void main(String[] args) {
    String[] states = new String[] {HEALTHY, FEVER};
    String[] observations = new String[] {DIZZY, COLD, NORMAL};

    Map<String, Float> start_probability = new HashMap<String, Float>();
    start_probability.put(HEALTHY, 0.6f);
    start_probability.put(FEVER, 0.4f);

    // transition_probability.get(from).get(to); each row sums to 1
    Map<String, Map<String, Float>> transition_probability =
        new HashMap<String, Map<String, Float>>();
    Map<String, Float> t1 = new HashMap<String, Float>();
    t1.put(HEALTHY, 0.7f);
    t1.put(FEVER, 0.3f);
    Map<String, Float> t2 = new HashMap<String, Float>();
    t2.put(HEALTHY, 0.4f);
    t2.put(FEVER, 0.6f);
    transition_probability.put(HEALTHY, t1);
    transition_probability.put(FEVER, t2);

    // emission_probability.get(state).get(observation)
    Map<String, Map<String, Float>> emission_probability =
        new HashMap<String, Map<String, Float>>();
    Map<String, Float> e1 = new HashMap<String, Float>();
    e1.put(DIZZY, 0.1f);
    e1.put(COLD, 0.4f);
    e1.put(NORMAL, 0.5f);
    Map<String, Float> e2 = new HashMap<String, Float>();
    e2.put(DIZZY, 0.6f);
    e2.put(COLD, 0.3f);
    e2.put(NORMAL, 0.1f);
    emission_probability.put(HEALTHY, e1);
    emission_probability.put(FEVER, e2);

    Object[] ret =
        forward_viterbi(
            observations, states, start_probability, transition_probability, emission_probability);
    System.out.println(((Float) ret[0]).floatValue());
    System.out.println((String) ret[1]);
    System.out.println(((Float) ret[2]).floatValue());
  }

  /**
   * Runs the forward algorithm and the Viterbi algorithm over an observation sequence.
   *
   * <p>The emission probability of the state occupied at each time step is applied at that
   * step. The returned path therefore contains exactly one hidden state per observation, with
   * no extra trailing state or trailing transition factor.
   *
   * <p>NOTE: computations are done in {@code float} linear space, so long sequences underflow
   * to 0; switch to log-probabilities if that matters.
   *
   * @param obs observation sequence
   * @param states all hidden states; ties are broken in favour of the earlier state
   * @param start_p {@code start_p.get(s)} = probability that the first hidden state is s
   * @param trans_p {@code trans_p.get(a).get(b)} = probability of moving from state a to b
   * @param emit_p {@code emit_p.get(s).get(o)} = probability that state s emits observation o
   * @return {@code {Float total, String path, Float valmax}}:
   *     <ul>
   *       <li>{@code total} is the summed probability of {@code obs} over all state paths.
   *       <li>{@code path} is the comma-separated most likely state sequence.
   *       <li>{@code valmax} is the joint probability of that path and {@code obs}.
   *     </ul>
   *     With no observations the result is {@code {1, "", 1}}; with no states it is
   *     {@code {0, "", 0}}.
   * @throws NullPointerException if a required probability is missing from the maps
   */
  public static Object[] forward_viterbi(
      String[] obs,
      String[] states,
      Map<String, Float> start_p,
      Map<String, ? extends Map<String, Float>> trans_p,
      Map<String, ? extends Map<String, Float>> emit_p) {
    int n = states.length;
    if (n == 0) {
      return new Object[] {0f, "", 0f};
    }
    if (obs.length == 0) {
      // The empty observation sequence (and the empty path) has probability 1.
      return new Object[] {1f, "", 1f};
    }

    float[] alpha = new float[n]; // forward probabilities at the current step
    float[] delta = new float[n]; // best-path probabilities at the current step
    int[][] backPointer = new int[obs.length][n]; // best predecessor index per step/state

    for (int i = 0; i < n; i++) {
      float p = start_p.get(states[i]) * emit_p.get(states[i]).get(obs[0]);
      alpha[i] = p;
      delta[i] = p;
    }

    for (int t = 1; t < obs.length; t++) {
      float[] nextAlpha = new float[n];
      float[] nextDelta = new float[n];
      for (int j = 0; j < n; j++) {
        float sum = 0f;
        float bestScore = Float.NEGATIVE_INFINITY; // so some predecessor is always chosen
        int bestPrev = 0;
        for (int i = 0; i < n; i++) {
          float a = trans_p.get(states[i]).get(states[j]);
          sum += alpha[i] * a;
          float score = delta[i] * a;
          if (score > bestScore) {
            bestScore = score;
            bestPrev = i;
          }
        }
        float e = emit_p.get(states[j]).get(obs[t]);
        nextAlpha[j] = sum * e;
        nextDelta[j] = bestScore * e;
        backPointer[t][j] = bestPrev;
      }
      alpha = nextAlpha;
      delta = nextDelta;
    }

    float total = 0f;
    int last = 0;
    for (int i = 0; i < n; i++) {
      total += alpha[i];
      if (delta[i] > delta[last]) {
        last = i;
      }
    }
    float valmax = delta[last];

    // Walk the back-pointers from the best final state to recover the path.
    String[] path = new String[obs.length];
    int current = last;
    for (int t = obs.length - 1; t >= 0; t--) {
      path[t] = states[current];
      current = backPointer[t][current];
    }
    StringBuilder joined = new StringBuilder();
    for (int t = 0; t < path.length; t++) {
      if (t > 0) {
        joined.append(',');
      }
      joined.append(path[t]);
    }
    return new Object[] {total, joined.toString(), valmax};
  }
}
F# implementation
(* Nick Heiner *)

(* Viterbi algorithm, as described here: http://people.ccmr.cornell.edu/~ginsparg/INFO295/vit.pdf

   priorProbs: prior probability of a hidden state occurring
   transitions: probability of one hidden state transitioning into another, keyed (from, to)
   emissionProbs: probability of a hidden state emitting an observed state, keyed (observed, hidden)
   observation: a sequence of observed states (must be non-empty)
   hiddens: a list of all possible hidden states (must be non-empty)

   Returns: probability of most likely path * hidden state list representing the path

   Each time step reuses the best paths computed for the previous step (dynamic programming).
   The running time is therefore O(|observation| * |hiddens|^2), not exponential in the length
   of the observation sequence. Ties are broken in favour of the state listed first in hiddens. *)
let viterbi (priorProbs : 'hidden -> float)
            (transitions : ('hidden * 'hidden) -> float)
            (emissionProbs : (('observed * 'hidden) -> float))
            (observation : 'observed [])
            (hiddens : 'hidden list) : float * 'hidden list =
    if Array.isEmpty observation then
        invalidArg "observation" "observation must contain at least one element"

    (* Referred to as v_state(time) in the notes. For each hidden state (in hiddens order), this
       holds the probability of the most probable path ending in that state at the current time,
       paired with that path in reverse order (most recent state first). *)
    let initial =
        hiddens
        |> List.map (fun state -> emissionProbs (observation.[0], state) * priorProbs state, [state])

    (* Advance one time step: for each state, extend the best predecessor path into it *)
    let step (previous : (float * 'hidden list) list) (observed : 'observed) =
        hiddens
        |> List.map (fun state ->
            let prob, path =
                previous
                (* Rate each path by how likely it is to transition into the current state *)
                |> List.map (fun (prob, path) -> transitions (List.head path, state) * prob, path)
                |> List.maxBy fst
            emissionProbs (observed, state) * prob, state :: path)

    (* Look for the most likely path that ends at t_max *)
    let prob, revPath =
        observation.[1..]
        |> Array.fold step initial
        |> List.maxBy fst
    prob, List.rev revPath

(* Example using the data from the Wikipedia Viterbi algorithm article *)
type wikiHiddens = Healthy | Fever
let wikiHiddenList = [Healthy; Fever]
type wikiObservations = Normal | Cold | Dizzy

let wikiPriors = function
    | Healthy -> 0.6
    | Fever -> 0.4

(* Outgoing probabilities from each state must sum to 1 *)
let wikiTransitions = function
    | (Healthy, Healthy) -> 0.7
    | (Healthy, Fever) -> 0.3
    | (Fever, Healthy) -> 0.4
    | (Fever, Fever) -> 0.6

let wikiEmissions = function
    | (Cold, Healthy) -> 0.4
    | (Normal, Healthy) -> 0.5
    | (Dizzy, Healthy) -> 0.1
    | (Cold, Fever) -> 0.3
    | (Normal, Fever) -> 0.1
    | (Dizzy, Fever) -> 0.6

viterbi wikiPriors wikiTransitions wikiEmissions [| Dizzy; Normal; Cold |] wikiHiddenList
Clojure implementation
(ns ident.viterbi
  (:use [clojure.pprint]))

;; Basis for HMM struct maps; make-hmm additionally stores :states and :obs.
(defstruct hmm :n :m :init-probs :emission-probs :state-transitions)

(defn make-hmm
  "Builds an hmm struct map from a map of :states, :obs, :init-probs,
  :emission-probs and :state-transitions. The probability tables are Java
  arrays read with aget: :init-probs is indexed [state], :emission-probs
  [state observation] and :state-transitions [from-state to-state]."
  [{:keys [states, obs, init-probs, emission-probs, state-transitions]}]
  (struct-map hmm
              :n (count states)
              :m (count obs)
              :states states
              :obs obs
              :init-probs init-probs                 ;; n
              :emission-probs emission-probs         ;; n x m, indexed [state observation]
              :state-transitions state-transitions)) ;; n x n, indexed [from to]

(defn indexed
  "Pairs each element of s with its position: ([0 x0] [1 x1] ...)."
  [s]
  (map vector (iterate inc 0) s))

(defn argmax
  "Returns [index value] of the first maximal element of coll, or nil if coll is empty."
  [coll]
  (loop [s (indexed coll)
         best (first s)]
    (if (empty? s)
      best
      (let [[_ elt] (first s)
            [_ best-elt] best]
        (if (> elt best-elt)
          (recur (rest s) (first s))
          (recur (rest s) best))))))

(defn pprint-hmm
  "Prints the sizes and probability tables of hmm."
  [hmm]
  (println "number of states: " (:n hmm) " number of observations: " (:m hmm))
  (print "init probabilities: ")
  (pprint (:init-probs hmm))
  (print "emission probs: ")
  (pprint (:emission-probs hmm))
  (print "state-transitions: ")
  (pprint (:state-transitions hmm)))

;; The per-step functions below return realized vectors (vec) rather than lazy
;; seqs. Otherwise each step's seq closes over the previous unrealized seq, and
;; forcing the final one recurses once per observation, overflowing the stack
;; on long inputs.

(defn init-alphas
  "alpha_0(i) = init(i) * emission(i, obs) for every state i."
  [hmm obs]
  (vec (map (fn [i] (* (aget (:init-probs hmm) i)
                       (aget (:emission-probs hmm) i obs)))
            (range (:n hmm)))))

(defn forward
  "One forward step: alpha'(j) = (sum_i alpha(i) * a(i,j)) * emission(j, obs)."
  [hmm alphas obs]
  (vec (map (fn [j]
              (* (reduce (fn [sum i]
                           (+ sum (* (nth alphas i) (aget (:state-transitions hmm) i j))))
                         0
                         (range (:n hmm)))
                 (aget (:emission-probs hmm) j obs)))
            (range (:n hmm)))))

(defn delta-max
  "One Viterbi step: delta'(j) = (max_i delta(i) * a(i,j)) * emission(j, obs)."
  [hmm deltas obs]
  (vec (map (fn [j]
              (* (apply max (map (fn [i] (* (nth deltas i) (aget (:state-transitions hmm) i j)))
                                 (range (:n hmm))))
                 (aget (:emission-probs hmm) j obs)))
            (range (:n hmm)))))

(defn backtrack
  "Follows the back-pointer vectors in paths, starting from the state with the
  largest final delta. Returns the state indices in time order."
  [paths deltas]
  (loop [path (reverse paths)
         term (first (argmax deltas))
         acc []]
    (if (empty? path)
      (reverse (conj acc term))
      (recur (rest path) (nth (first path) term) (conj acc term)))))

(defn update-paths
  "Back-pointers for one step: for each state j, the index i maximising delta(i) * a(i,j)."
  [hmm deltas]
  (vec (map (fn [j]
              (first (argmax (map (fn [i] (* (nth deltas i) (aget (:state-transitions hmm) i j)))
                                  (range (:n hmm))))))
            (range (:n hmm)))))

(defn viterbi
  "Runs the Viterbi and forward algorithms over the observation indices in observs.
  Returns [state-index-path probability], where probability is the forward
  probability of observs, coerced to float. Throws IllegalArgumentException if
  observs is empty."
  [hmm observs]
  (when (empty? observs)
    (throw (IllegalArgumentException. "observs must not be empty")))
  (loop [obs (rest observs)
         alphas (init-alphas hmm (first observs))
         deltas alphas
         paths []]
    (if (empty? obs)
      [(backtrack paths deltas) (float (reduce + alphas))]
      (recur (rest obs)
             (forward hmm alphas (first obs))
             (delta-max hmm deltas (first obs))
             ;; back-pointers are computed from the previous step's deltas
             (conj paths (update-paths hmm deltas))))))
C# implementation
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Viterbi
{
    class Program
    {
        // Hidden (health) states
        const String HEALTHY = "Healthy";
        const String FEVER = "Fever";
        // Observations
        const String DIZZY = "dizzy";
        const String COLD = "cold";
        const String NORMAL = "normal";

        static void Main(string[] args)
        {
            // Initialize our arrays of states and observations
            String[] states = { HEALTHY, FEVER };
            String[] observations = { DIZZY, COLD, NORMAL };

            var start_probability = new Dictionary<String, float>();
            start_probability.Add(HEALTHY, 0.6f);
            start_probability.Add(FEVER, 0.4f);

            // Transition probability: transition_probability[from][to]; each row sums to 1
            var transition_probability = new Dictionary<String, Dictionary<String, float>>();
            var t1 = new Dictionary<String, float>();
            t1.Add(HEALTHY, 0.7f);
            t1.Add(FEVER, 0.3f);
            var t2 = new Dictionary<String, float>();
            t2.Add(HEALTHY, 0.4f);
            t2.Add(FEVER, 0.6f);
            transition_probability.Add(HEALTHY, t1);
            transition_probability.Add(FEVER, t2);

            // Emission probability: emission_probability[state][observation]
            var emission_probability = new Dictionary<String, Dictionary<String, float>>();
            var e1 = new Dictionary<String, float>();
            e1.Add(DIZZY, 0.1f);
            e1.Add(COLD, 0.4f);
            e1.Add(NORMAL, 0.5f);
            var e2 = new Dictionary<String, float>();
            e2.Add(DIZZY, 0.6f);
            e2.Add(COLD, 0.3f);
            e2.Add(NORMAL, 0.1f);
            emission_probability.Add(HEALTHY, e1);
            emission_probability.Add(FEVER, e2);

            Object[] ret = forward_viterbi(observations, states, start_probability, transition_probability, emission_probability);
            Console.WriteLine((float)ret[0]);
            Console.WriteLine((String)ret[1]);
            Console.WriteLine((float)ret[2]);
            Console.ReadLine();
        }

        /// <summary>
        /// Runs the forward algorithm and the Viterbi algorithm over an observation sequence.
        /// The emission probability of the state occupied at each time step is applied at that
        /// step. The returned path therefore has exactly one state per observation, with no
        /// extra trailing state or trailing transition factor. Ties go to the earlier state.
        /// Computations are done in float linear space, so long sequences underflow to 0.
        /// </summary>
        /// <param name="obs">Observation sequence.</param>
        /// <param name="states">All hidden states.</param>
        /// <param name="start_p">start_p[s] = probability that the first hidden state is s.</param>
        /// <param name="trans_p">trans_p[a][b] = probability of moving from state a to b.</param>
        /// <param name="emit_p">emit_p[s][o] = probability that state s emits observation o.</param>
        /// <returns>
        /// { float total, String path, float valmax }. total is the summed probability of obs
        /// over all state paths. path is the comma-separated most likely state sequence. valmax
        /// is the joint probability of that path and obs. Returns { 1, "", 1 } for no
        /// observations and { 0, "", 0 } for no states.
        /// </returns>
        /// <exception cref="KeyNotFoundException">A required probability is missing.</exception>
        public static Object[] forward_viterbi(String[] obs, String[] states, Dictionary<String, float> start_p, Dictionary<String, Dictionary<String, float>> trans_p, Dictionary<String, Dictionary<String, float>> emit_p)
        {
            int n = states.Length;
            if (n == 0)
            {
                return new Object[] { 0f, "", 0f };
            }
            if (obs.Length == 0)
            {
                // The empty observation sequence (and the empty path) has probability 1.
                return new Object[] { 1f, "", 1f };
            }

            var alpha = new float[n];                    // forward probabilities
            var delta = new float[n];                    // best-path probabilities
            var backPointer = new int[obs.Length, n];    // best predecessor per step/state

            for (int i = 0; i < n; i++)
            {
                float p = start_p[states[i]] * emit_p[states[i]][obs[0]];
                alpha[i] = p;
                delta[i] = p;
            }

            for (int t = 1; t < obs.Length; t++)
            {
                var nextAlpha = new float[n];
                var nextDelta = new float[n];
                for (int j = 0; j < n; j++)
                {
                    float sum = 0f;
                    float bestScore = float.NegativeInfinity; // so some predecessor is always chosen
                    int bestPrev = 0;
                    for (int i = 0; i < n; i++)
                    {
                        float a = trans_p[states[i]][states[j]];
                        sum += alpha[i] * a;
                        float score = delta[i] * a;
                        if (score > bestScore)
                        {
                            bestScore = score;
                            bestPrev = i;
                        }
                    }
                    float e = emit_p[states[j]][obs[t]];
                    nextAlpha[j] = sum * e;
                    nextDelta[j] = bestScore * e;
                    backPointer[t, j] = bestPrev;
                }
                alpha = nextAlpha;
                delta = nextDelta;
            }

            float total = 0f;
            int last = 0;
            for (int i = 0; i < n; i++)
            {
                total += alpha[i];
                if (delta[i] > delta[last])
                {
                    last = i;
                }
            }
            float valmax = delta[last];

            // Walk the back-pointers from the best final state to recover the path.
            var path = new String[obs.Length];
            int current = last;
            for (int t = obs.Length - 1; t >= 0; t--)
            {
                path[t] = states[current];
                current = backPointer[t, current];
            }
            return new Object[] { total, String.Join(",", path), valmax };
        }
    }
}