Quick Info

Category: String Processing
Time Complexity: O(n)
Space Complexity: O(k)
Input Type: Text

Markov Chain Text Generation

Description

A Markov chain is a stochastic model describing a sequence of possible events in which the probability of each event depends only on the state attained in the previous event. In text generation, it is used to predict the next word or character from the previous n words or characters, where n is the chain's order.
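As a concrete illustration (the six-word corpus below is hypothetical), the probability distribution for a state is just the normalized counts of whatever follows that state in the training text:

```python
from collections import Counter

# Hypothetical corpus: word-level chain, order 1.
words = "the cat sat on the mat".split()
pairs = Counter(zip(words, words[1:]))

# P(next | "the"): count the words that follow "the", then normalize.
following_the = Counter(b for (a, b) in pairs.elements() if a == "the")
total = sum(following_the.values())
probs = {w: c / total for w, c in following_the.items()}
print(probs)  # {'cat': 0.5, 'mat': 0.5}
```

"the" is followed once by "cat" and once by "mat", so each gets probability 0.5; every other state here is deterministic.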

How It Works

  1. Training Phase:
    • Analyze input text to build state transitions
    • Create a probability distribution for each state
    • Store transitions in a dictionary/map structure
  2. Generation Phase:
    • Start with an initial state
    • Randomly select next state based on probabilities
    • Repeat until desired length or end condition
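The two phases above can be sketched with a tiny order-1 character model (the corpus "abab" is an illustrative assumption; padding is omitted for brevity):

```python
from collections import defaultdict
import random

# Training phase: map each length-1 state to the characters that follow it.
text = "abab"
order = 1
transitions = defaultdict(list)
for i in range(len(text) - order):
    state = text[i:i + order]
    transitions[state].append(text[i + order])
print(dict(transitions))  # {'a': ['b', 'b'], 'b': ['a']}

# Generation phase: start from a state and repeatedly sample the next character.
state = "a"
out = [state]
for _ in range(5):
    nxt = random.choice(transitions[state])
    out.append(nxt)
    state = nxt
print("".join(out))  # 'ababab' (this toy chain is fully deterministic)
```

Note that duplicates in a state's list ('a' maps to ['b', 'b']) are how raw counts encode the probability distribution: sampling uniformly from the list reproduces the observed frequencies.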

Python Implementation

from collections import defaultdict
import random
from typing import Dict

class MarkovChain:
    def __init__(self, order: int = 2):
        """Initialize Markov Chain with specified order (chain length)."""
        self.order = order
        self.transitions = defaultdict(list)
        
    def train(self, text: str) -> None:
        """Train the Markov Chain on input text."""
        # Add padding to handle start/end
        padded_text = " " * self.order + text + " " * self.order
        
        # Build transitions dictionary
        for i in range(len(padded_text) - self.order):
            state = padded_text[i:i + self.order]
            next_char = padded_text[i + self.order]
            self.transitions[state].append(next_char)
    
    def generate(self, length: int = 100) -> str:
        """Generate text using the trained Markov Chain."""
        if not self.transitions:
            return "Error: Model not trained"
        
        # Start with a random state
        current = random.choice(list(self.transitions.keys()))
        result = list(current)
        
        # Generate subsequent characters
        for _ in range(length - self.order):
            if current not in self.transitions:
                break
                
            next_char = random.choice(self.transitions[current])
            result.append(next_char)
            current = current[1:] + next_char
        
        return ''.join(result).strip()
    
    def get_statistics(self) -> Dict[str, float]:
        """Get statistics about the trained model."""
        total = sum(len(v) for v in self.transitions.values())
        num_states = len(self.transitions)
        return {
            'states': num_states,
            'total_transitions': total,
            'avg_transitions': total / num_states if num_states else 0
        }

def main():
    # Example usage
    text = """
    The quick brown fox jumps over the lazy dog.
    A quick brown dog jumps over the lazy fox.
    The lazy fox and dog are quick to jump.
    """
    
    # Create and train model
    markov = MarkovChain(order=2)
    markov.train(text)
    
    # Generate text
    generated = markov.generate(length=100)
    print("Generated text:")
    print(generated)
    
    # Print statistics
    stats = markov.get_statistics()
    print("\nModel Statistics:")
    for key, value in stats.items():
        print(f"{key}: {value}")

if __name__ == "__main__":
    main()

C++ Implementation

#include <iostream>
#include <string>
#include <map>
#include <vector>
#include <random>
#include <chrono>

class MarkovChain {
private:
    int order;
    std::map<std::string, std::vector<char>> transitions;
    std::mt19937 rng;

public:
    MarkovChain(int chainOrder = 2) : order(chainOrder) {
        // Initialize random number generator with time-based seed
        rng.seed(std::chrono::steady_clock::now().time_since_epoch().count());
    }

    void train(const std::string& text) {
        // Add padding to handle start/end
        std::string paddedText = std::string(order, ' ') + text + std::string(order, ' ');

        // Build transitions map
        for (size_t i = 0; i < paddedText.length() - order; ++i) {
            std::string state = paddedText.substr(i, order);
            char nextChar = paddedText[i + order];
            transitions[state].push_back(nextChar);
        }
    }

    std::string generate(size_t length = 100) {
        if (transitions.empty()) {
            return "Error: Model not trained";
        }

        // Start with a random state
        std::vector<std::string> states;
        for (const auto& pair : transitions) {
            states.push_back(pair.first);
        }
        
        std::uniform_int_distribution<size_t> stateDist(0, states.size() - 1);
        std::string current = states[stateDist(rng)];
        std::string result = current;

        // Generate subsequent characters (start the counter at `order` to
        // avoid size_t underflow in `length - order` when length < order)
        for (size_t i = order; i < length; ++i) {
            auto it = transitions.find(current);
            if (it == transitions.end() || it->second.empty()) {
                break;
            }

            std::uniform_int_distribution<size_t> charDist(0, it->second.size() - 1);
            char nextChar = it->second[charDist(rng)];
            result += nextChar;
            current = result.substr(result.length() - order);
        }

        return result;
    }

    void printStats() const {
        size_t totalTransitions = 0;
        for (const auto& pair : transitions) {
            totalTransitions += pair.second.size();
        }

        std::cout << "Model Statistics:\n";
        std::cout << "States: " << transitions.size() << "\n";
        std::cout << "Total transitions: " << totalTransitions << "\n";
        std::cout << "Average transitions per state: " 
                  << (transitions.empty() ? 0.0 : 
                      static_cast<double>(totalTransitions) / transitions.size())
                  << "\n";
    }
};

int main() {
    std::string text = R"(
        The quick brown fox jumps over the lazy dog.
        A quick brown dog jumps over the lazy fox.
        The lazy fox and dog are quick to jump.
    )";

    MarkovChain markov(2);
    markov.train(text);

    std::cout << "Generated text:\n";
    std::cout << markov.generate(100) << "\n\n";

    markov.printStats();

    return 0;
}

C# Implementation

using System;
using System.Collections.Generic;
using System.Linq;

public class MarkovChain
{
    private readonly int order;
    private readonly Dictionary<string, List<char>> transitions;
    private readonly Random random;

    public MarkovChain(int order = 2)
    {
        this.order = order;
        this.transitions = new Dictionary<string, List<char>>();
        this.random = new Random();
    }

    public void Train(string text)
    {
        // Add padding to handle start/end
        string paddedText = new string(' ', order) + text + new string(' ', order);

        // Build transitions dictionary
        for (int i = 0; i < paddedText.Length - order; i++)
        {
            string state = paddedText.Substring(i, order);
            char nextChar = paddedText[i + order];

            if (!transitions.ContainsKey(state))
            {
                transitions[state] = new List<char>();
            }
            transitions[state].Add(nextChar);
        }
    }

    public string Generate(int length = 100)
    {
        if (!transitions.Any())
        {
            return "Error: Model not trained";
        }

        // Start with a random state
        string current = transitions.Keys.ElementAt(random.Next(transitions.Count));
        var result = new List<char>(current);

        // Generate subsequent characters
        for (int i = 0; i < length - order; i++)
        {
            if (!transitions.ContainsKey(current))
            {
                break;
            }

            var possibleNextChars = transitions[current];
            char nextChar = possibleNextChars[random.Next(possibleNextChars.Count)];
            result.Add(nextChar);
            current = new string(result.Skip(result.Count - order).Take(order).ToArray());
        }

        return new string(result.ToArray()).Trim();
    }

    public Dictionary<string, object> GetStatistics()
    {
        int totalTransitions = transitions.Values.Sum(v => v.Count);
        double avgTransitions = transitions.Any() 
            ? (double)totalTransitions / transitions.Count 
            : 0;

        return new Dictionary<string, object>
        {
            ["States"] = transitions.Count,
            ["TotalTransitions"] = totalTransitions,
            ["AverageTransitionsPerState"] = avgTransitions
        };
    }

    public static void Main()
    {
        string text = @"
            The quick brown fox jumps over the lazy dog.
            A quick brown dog jumps over the lazy fox.
            The lazy fox and dog are quick to jump.
        ";

        var markov = new MarkovChain(2);
        markov.Train(text);

        Console.WriteLine("Generated text:");
        Console.WriteLine(markov.Generate(100));
        Console.WriteLine();

        var stats = markov.GetStatistics();
        Console.WriteLine("Model Statistics:");
        foreach (var stat in stats)
        {
            Console.WriteLine($"{stat.Key}: {stat.Value}");
        }
    }
}

Complexity Analysis

Time Complexity

  • Training: O(n), where n is the length of input text
  • Generation: O(m), where m is the desired output length

Space Complexity

  • O(k), where k is the number of unique states
  • k depends on the order of the Markov chain and input text variety
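One way to see how k grows with the chain's order is to count the distinct length-order substrings of the training text, since each one becomes a state; the pangram used here is just an illustrative input:

```python
def count_states(text: str, order: int) -> int:
    # Number of distinct length-`order` substrings = number of states.
    return len({text[i:i + order] for i in range(len(text) - order + 1)})

text = "the quick brown fox jumps over the lazy dog"
for order in (1, 2, 3):
    print(order, count_states(text, order))
# order 1 yields 27 states (26 letters + space); higher orders yield more
```

The state count is bounded above by both a^order (for an alphabet of size a) and the length of the training text, which is why space usage depends on both the order and the input's variety.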

Applications

  • Text Generation
  • Predictive Text
  • Language Modeling
  • Speech Synthesis
  • Music Composition

Advantages and Disadvantages

Advantages

  • Simple to implement
  • Fast training and generation
  • Memory efficient
  • Works with any sequence data

Disadvantages

  • Limited context understanding
  • Can produce nonsensical output
  • Cannot capture long-range dependencies
  • Requires a large training corpus for varied output