其他分享
首页 > 其他分享 > 2021-09-29

2021-09-29

作者:互联网

from collections import defaultdict

# Two-state weather HMM: the hidden weather ("Rainy"/"Sunny") emits an
# observed activity ("Walk"/"Shop"/"Clean").
states = ("Rainy", "Sunny")
observations = ("Walk", "Shop", "Clean")

# Initial distribution over hidden states, values listed in `states` order.
start_probability = dict(zip(states, (0.6, 0.4)))

# One row per current state; each row's values follow the order of `states`
# and give the probability of the next state (each row sums to 1).
transition_probability = {
    "Rainy": dict(zip(states, (0.7, 0.3))),
    "Sunny": dict(zip(states, (0.4, 0.6))),
}

# One row per hidden state; each row's values follow the order of
# `observations` and give the probability of emitting that observation.
emission_probability = {
    "Rainy": dict(zip(observations, (0.1, 0.4, 0.5))),
    "Sunny": dict(zip(observations, (0.6, 0.3, 0.1))),
}

def compute(obs, states, start_p, trans_p, emit_p):
    """Return the most likely hidden-state sequence for ``obs`` (Viterbi decoding).

    Args:
        obs: Sequence of observed symbols (must support ``len`` and slicing).
        states: Hidden states. They are iterated several times, so the
            iterable is materialised into a list first.
        start_p: ``start_p[s]`` is the probability of starting in state ``s``.
        trans_p: ``trans_p[a][b]`` is the probability of moving from ``a`` to ``b``.
        emit_p: ``emit_p[s][o]`` is the probability of state ``s`` emitting ``o``.

    Returns:
        A list with one state per observation forming the highest-probability
        path, or ``[]`` when ``obs`` is empty. Ties are resolved in favour of
        the state listed first in ``states``.

    Raises:
        ValueError: if ``states`` is empty.
        KeyError: if a state or observation is missing from a probability table.
    """
    states = list(states)
    if not states:
        raise ValueError("states must not be empty")
    if len(obs) == 0:
        return []

    # best_prob[s]: probability of the most likely path ending in s at the
    # current step; best_path[s]: that path itself.
    best_prob = {s: start_p[s] * emit_p[s][obs[0]] for s in states}
    best_path = {s: [s] for s in states}

    for symbol in obs[1:]:
        # Build the next step in fresh dicts. Overwriting best_path in place
        # would let a later target state extend a path that was already
        # advanced during this same step, yielding a too-long, wrong path.
        next_prob = {}
        next_path = {}
        for y1 in states:
            # max() returns the first maximal candidate, i.e. the same
            # predecessor a strict ">" scan over `states` in order would pick.
            prev, prob = max(
                (
                    (y0, best_prob[y0] * trans_p[y0][y1] * emit_p[y1][symbol])
                    for y0 in states
                ),
                key=lambda candidate: candidate[1],
            )
            next_prob[y1] = prob
            next_path[y1] = best_path[prev] + [y1]
        best_prob, best_path = next_prob, next_path

    # NOTE: the probabilities are plain products that shrink every step and
    # may underflow to 0.0 on long sequences; a log-space variant avoids that.
    final_state = max(states, key=best_prob.__getitem__)
    return best_path[final_state]


if __name__ == "__main__":
    # Decode the demo observation sequence and print the best weather path.
    max_path = compute(
        obs=observations,
        states=states,
        start_p=start_probability,
        trans_p=transition_probability,
        emit_p=emission_probability,
    )
    print(max_path)

（原文此处插有一张图片，未附说明。）

标签:max,09,29,obs,state,2021,y1,path,prob
来源: https://blog.csdn.net/llacr/article/details/120558314