Markov Chain Text Generation
Description
A Markov Chain is a stochastic model describing a sequence of possible events in which the probability of each event depends only on the current state, not on the full history. In text generation, it is used to predict the next word or character from the previous n words or characters.
How It Works
- Training Phase:
- Analyze input text to build state transitions
- Create a probability distribution for each state
- Store transitions in a dictionary/map structure
- Generation Phase:
- Start with an initial state
- Randomly select next state based on probabilities
- Repeat until desired length or end condition
Python Implementation
from collections import defaultdict
import random
from typing import Dict, List, Tuple
class MarkovChain:
    """Character-level Markov chain text generator.

    States are fixed-length character windows; training records which
    character follows each window, and generation samples from those
    recorded followers.
    """

    def __init__(self, order: int = 2):
        """Initialize Markov Chain with specified order (state length in characters)."""
        self.order = order
        # Maps an `order`-character state to the list of characters seen after it.
        # Duplicates are kept deliberately: choosing uniformly from the list
        # reproduces the empirical follower distribution.
        self.transitions = defaultdict(list)

    def train(self, text: str) -> None:
        """Train the Markov Chain on input text.

        May be called repeatedly to accumulate data from several texts.
        """
        # Pad with spaces so the very first/last characters also get states.
        padded_text = " " * self.order + text + " " * self.order
        # Build transitions dictionary: each window of `order` chars maps to
        # the character immediately following it.
        for i in range(len(padded_text) - self.order):
            state = padded_text[i:i + self.order]
            next_char = padded_text[i + self.order]
            self.transitions[state].append(next_char)

    def generate(self, length: int = 100) -> str:
        """Generate up to `length` characters using the trained chain.

        Returns an error string if `train` was never called. Output may be
        shorter than `length` if a dead-end state is reached or after
        stripping surrounding whitespace.
        """
        if not self.transitions:
            return "Error: Model not trained"
        # Start with a uniformly random known state.
        current = random.choice(list(self.transitions.keys()))
        result = list(current)
        # Generate subsequent characters; max(0, ...) guards length < order.
        for _ in range(max(0, length - self.order)):
            if current not in self.transitions:
                break  # dead end: state never observed during training
            next_char = random.choice(self.transitions[current])
            result.append(next_char)
            # Slide the window forward by one character.
            current = current[1:] + next_char
        return ''.join(result).strip()

    def get_statistics(self) -> Dict[str, int]:
        """Return counts describing the trained model."""
        total = sum(len(v) for v in self.transitions.values())
        return {
            'states': len(self.transitions),
            'total_transitions': total,
            'avg_transitions': total / len(self.transitions) if self.transitions else 0
        }
def main():
    """Demonstrate training a 2nd-order chain on a tiny corpus."""
    # Example usage: three similar sentences sharing many bigrams.
    text = """
The quick brown fox jumps over the lazy dog.
A quick brown dog jumps over the lazy fox.
The lazy fox and dog are quick to jump.
"""
    # Create and train model
    markov = MarkovChain(order=2)
    markov.train(text)
    # Generate text
    generated = markov.generate(length=100)
    print("Generated text:")
    print(generated)
    # Print statistics
    stats = markov.get_statistics()
    print("\nModel Statistics:")
    for key, value in stats.items():
        print(f"{key}: {value}")


if __name__ == "__main__":
    main()
#include <iostream>
#include <string>
#include <map>
#include <vector>
#include <random>
#include <chrono>
// Character-level Markov chain text generator: states are fixed-length
// character windows; training records which character follows each window.
class MarkovChain {
private:
    int order;
    // Maps an `order`-character state to every character observed after it.
    // Duplicates are kept so uniform sampling matches the empirical distribution.
    std::map<std::string, std::vector<char>> transitions;
    std::mt19937 rng;

public:
    // chainOrder: number of characters that make up one state.
    MarkovChain(int chainOrder = 2) : order(chainOrder) {
        // Time-based seed so repeated runs produce different output.
        rng.seed(static_cast<std::mt19937::result_type>(
            std::chrono::steady_clock::now().time_since_epoch().count()));
    }

    // Train on `text`; may be called repeatedly to accumulate data.
    void train(const std::string& text) {
        // Pad with spaces so the very first/last characters also get states.
        const std::string paddedText =
            std::string(order, ' ') + text + std::string(order, ' ');
        // i + order < length avoids any unsigned-subtraction pitfalls.
        for (size_t i = 0; i + order < paddedText.length(); ++i) {
            std::string state = paddedText.substr(i, order);
            transitions[state].push_back(paddedText[i + order]);
        }
    }

    // Generate up to `length` characters; returns an error string if untrained.
    std::string generate(size_t length = 100) {
        if (transitions.empty()) {
            return "Error: Model not trained";
        }
        // Pick a uniformly random starting state.
        std::vector<std::string> states;
        states.reserve(transitions.size());
        for (const auto& pair : transitions) {
            states.push_back(pair.first);
        }
        std::uniform_int_distribution<size_t> stateDist(0, states.size() - 1);
        std::string current = states[stateDist(rng)];
        std::string result = current;
        // BUG FIX: `length - order` underflowed (size_t) when length < order,
        // producing a near-infinite loop. Clamp the remaining count to zero.
        const size_t remaining =
            length > static_cast<size_t>(order) ? length - order : 0;
        for (size_t i = 0; i < remaining; ++i) {
            auto it = transitions.find(current);
            if (it == transitions.end() || it->second.empty()) {
                break;  // dead end: state never observed during training
            }
            std::uniform_int_distribution<size_t> charDist(0, it->second.size() - 1);
            result += it->second[charDist(rng)];
            // Slide the window: last `order` characters of the output so far.
            current = result.substr(result.length() - order);
        }
        return result;
    }

    // Print state/transition counts for the trained model to stdout.
    void printStats() const {
        size_t totalTransitions = 0;
        for (const auto& pair : transitions) {
            totalTransitions += pair.second.size();
        }
        std::cout << "Model Statistics:\n";
        std::cout << "States: " << transitions.size() << "\n";
        std::cout << "Total transitions: " << totalTransitions << "\n";
        std::cout << "Average transitions per state: "
                  << (transitions.empty() ? 0.0 :
                      static_cast<double>(totalTransitions) / transitions.size())
                  << "\n";
    }
};
int main() {
std::string text = R"(
The quick brown fox jumps over the lazy dog.
A quick brown dog jumps over the lazy fox.
The lazy fox and dog are quick to jump.
)";
MarkovChain markov(2);
markov.train(text);
std::cout << "Generated text:\n";
std::cout << markov.generate(100) << "\n\n";
markov.printStats();
return 0;
}
using System;
using System.Collections.Generic;
using System.Linq;
/// <summary>
/// Character-level Markov chain text generator: states are fixed-length
/// character windows; training records which character follows each window.
/// </summary>
public class MarkovChain
{
    private readonly int order;
    // Maps an `order`-character state to every character observed after it.
    // Duplicates are kept so uniform sampling matches the empirical distribution.
    private readonly Dictionary<string, List<char>> transitions;
    private readonly Random random;

    /// <summary>Create a chain whose states are <paramref name="order"/> characters long.</summary>
    public MarkovChain(int order = 2)
    {
        this.order = order;
        this.transitions = new Dictionary<string, List<char>>();
        this.random = new Random();
    }

    /// <summary>Train on <paramref name="text"/>; may be called repeatedly to accumulate data.</summary>
    public void Train(string text)
    {
        // Pad with spaces so the very first/last characters also get states.
        string paddedText = new string(' ', order) + text + new string(' ', order);
        for (int i = 0; i < paddedText.Length - order; i++)
        {
            string state = paddedText.Substring(i, order);
            char nextChar = paddedText[i + order];
            // Single TryGetValue instead of ContainsKey + two indexer lookups.
            if (!transitions.TryGetValue(state, out var nextChars))
            {
                nextChars = new List<char>();
                transitions[state] = nextChars;
            }
            nextChars.Add(nextChar);
        }
    }

    /// <summary>Generate up to <paramref name="length"/> characters of text.</summary>
    /// <returns>Generated text, or an error string if the model is untrained.</returns>
    public string Generate(int length = 100)
    {
        if (transitions.Count == 0)
        {
            return "Error: Model not trained";
        }
        // Start with a uniformly random known state.
        string current = transitions.Keys.ElementAt(random.Next(transitions.Count));
        var result = new List<char>(current);
        // length <= order leaves the loop body unexecuted (int arithmetic, no underflow).
        for (int i = 0; i < length - order; i++)
        {
            if (!transitions.TryGetValue(current, out var possibleNextChars))
            {
                break;  // dead end: state never observed during training
            }
            char nextChar = possibleNextChars[random.Next(possibleNextChars.Count)];
            result.Add(nextChar);
            // Slide the window: last `order` characters of the output so far.
            current = new string(result.ToArray(), result.Count - order, order);
        }
        return new string(result.ToArray()).Trim();
    }

    /// <summary>Return counts describing the trained model.</summary>
    public Dictionary<string, object> GetStatistics()
    {
        int totalTransitions = transitions.Values.Sum(v => v.Count);
        double avgTransitions = transitions.Count > 0
            ? (double)totalTransitions / transitions.Count
            : 0;
        return new Dictionary<string, object>
        {
            ["States"] = transitions.Count,
            ["TotalTransitions"] = totalTransitions,
            ["AverageTransitionsPerState"] = avgTransitions
        };
    }

    /// <summary>Demo: train a 2nd-order chain on a tiny corpus, print a sample and stats.</summary>
    public static void Main()
    {
        string text = @"
The quick brown fox jumps over the lazy dog.
A quick brown dog jumps over the lazy fox.
The lazy fox and dog are quick to jump.
";
        var markov = new MarkovChain(2);
        markov.Train(text);
        Console.WriteLine("Generated text:");
        Console.WriteLine(markov.Generate(100));
        Console.WriteLine();
        var stats = markov.GetStatistics();
        Console.WriteLine("Model Statistics:");
        foreach (var stat in stats)
        {
            Console.WriteLine($"{stat.Key}: {stat.Value}");
        }
    }
}
Complexity Analysis
Time Complexity
- Training: O(n), where n is the length of input text
- Generation: O(m), where m is the desired output length
Space Complexity
- O(n) overall, since every observed transition is stored (one entry per character of training text)
- The number of unique states grows with the chain order and the variety of the input text
Applications
- Text Generation
- Predictive Text
- Language Modeling
- Speech Synthesis
- Music Composition
Advantages and Disadvantages
Advantages
- Simple to implement
- Fast training and generation
- Memory efficient
- Works with any sequence data
Disadvantages
- Limited context understanding
- Can produce nonsensical output
- No long-term dependencies
- Requires large training data