"""
Sionna NR Resource Grid Generator - SSB (Synchronization Signal Block)
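- Configurable channel bandwidth (CHANNEL_BW_MHZ, e.g. 20 MHz or 50 MHz)
- Subcarrier spacing follows the SSB case: 15 kHz (Case A), 30 kHz (Cases B/C),
  120 kHz (Case D), 240 kHz (Case E)
- SSB occupies 20 RBs (240 subcarriers) x 4 OFDM symbols
- Example: Case A at 3-6 GHz allows up to 8 SSB bursts per 5 ms half frame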
- Shows 5ms (half frame) of resource grid

3GPP TS 38.211 - SSB Structure:
- Symbol 0: PSS (127 subcarriers, centered)
- Symbol 1: PBCH + PBCH DMRS
- Symbol 2: PBCH + SSS (127 subcarriers, centered) + PBCH DMRS
- Symbol 3: PBCH + PBCH DMRS
"""

import numpy as np
import tensorflow as tf
import matplotlib
matplotlib.use('TkAgg')  # Use TkAgg backend for interactive display
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle, Patch
from matplotlib.colors import ListedColormap
import tkinter as tk
from tkinter import ttk

# =============================================================================
# Debug Flag
# =============================================================================
DEBUG = True  # Set to False to disable debug prints

# =============================================================================
# SSB Configuration Parameters (Global Variables for Future Extension)
# =============================================================================

# SSB Case Configuration (3GPP TS 38.213 Section 4.1)
# Case A: 15 kHz SCS, FR1 (< 6 GHz)
# Case B: 30 kHz SCS, FR1 (< 6 GHz)
# Case C: 30 kHz SCS, FR1 (< 6 GHz) - paired spectrum
# Case D: 120 kHz SCS, FR2 (> 6 GHz)
# Case E: 240 kHz SCS, FR2 (> 6 GHz)
SSB_CASE = 'B'                    # SSB case ('A', 'B', 'C', 'D', 'E')

# L_max Configuration (3GPP TS 38.213 Section 4.1)
# Maximum number of SS/PBCH blocks in a half frame (5ms)
# FR1 (< 3 GHz):   L_max = 4
# FR1 (3-6 GHz):   L_max = 8
# FR2 (> 6 GHz):   L_max = 64
L_MAX = 8                         # Maximum SSB bursts per half frame

# SSB transmission bitmap (length up to L_MAX), '1' = transmit, '0' = skip
# Example for Case A (L_MAX=4): '1111' transmits all 4 bursts
# Example for Case D (L_MAX=64): '11110000...' length 64
#SSB_TX_BITMAP = '1010'
#SSB_TX_BITMAP = '10101011'
SSB_TX_BITMAP = '11111111'
#SSB_TX_BITMAP = '1010101110101011101010111010101110101011101010111010101110101011'

# SSB Case Parameters (3GPP TS 38.213 Section 4.1)
# first_symbols:       candidate SSB start symbols within one pattern period
# slot_pattern_period: symbol step between repetitions of that pattern
#                      (start symbol = first_symbol + period * n)
SSB_CASE_PARAMS = {
    'A': {'scs_khz': 15, 'first_symbols': [2, 8], 'slot_pattern_period': 14},
    'B': {'scs_khz': 30, 'first_symbols': [4, 8, 16, 20], 'slot_pattern_period': 28},
    'C': {'scs_khz': 30, 'first_symbols': [2, 8], 'slot_pattern_period': 14},
    'D': {'scs_khz': 120, 'first_symbols': [4, 8, 16, 20], 'slot_pattern_period': 28},
    'E': {'scs_khz': 240, 'first_symbols': [8, 12, 16, 20, 32, 36, 40, 44], 'slot_pattern_period': 56},
}

# Reject invalid configuration up front with a readable message instead of a
# bare KeyError (or a silently truncated SSB pattern) further down.
if SSB_CASE not in SSB_CASE_PARAMS:
    raise ValueError(f"Invalid SSB case '{SSB_CASE}': expected one of {sorted(SSB_CASE_PARAMS)}")
if L_MAX not in (4, 8, 64):
    raise ValueError(f"Invalid L_MAX {L_MAX}: expected 4, 8 or 64")

# Get SSB parameters based on case
SSB_SCS_KHZ = SSB_CASE_PARAMS[SSB_CASE]['scs_khz']
SSB_FIRST_SYMBOLS = SSB_CASE_PARAMS[SSB_CASE]['first_symbols']
SSB_SUBCARRIER_SPACING = SSB_SCS_KHZ * 1e3  # Convert to Hz

# =============================================================================
# Channel Bandwidth Configuration
# =============================================================================

CHANNEL_BW_MHZ = 50               # Channel bandwidth in MHz
SUBCARRIER_SPACING = SSB_SUBCARRIER_SPACING  # Use SSB SCS for the grid

# Maximum transmission bandwidth N_RB per channel bandwidth (MHz) and SCS (kHz).
# FR1 rows follow 3GPP TS 38.101-1; the 120 kHz column and the 200/400 MHz rows
# follow TS 38.101-2 (FR2). None marks a combination that is not defined.
# NOTE(review): 240 kHz is an SSB-only SCS; the 100 MHz / 240 kHz entry is not
# an N_RB value from 38.101 - confirm where it comes from.
MAX_RBS_TABLE = {
    5:   {15: 25,  30: 11,  60: None, 120: None, 240: None},
    10:  {15: 52,  30: 24,  60: 11,   120: None, 240: None},
    15:  {15: 79,  30: 38,  60: 18,   120: None, 240: None},
    20:  {15: 106, 30: 51,  60: 24,   120: None, 240: None},
    25:  {15: 133, 30: 65,  60: 31,   120: None, 240: None},
    30:  {15: 160, 30: 78,  60: 38,   120: None, 240: None},
    40:  {15: 216, 30: 106, 60: 51,   120: None, 240: None},
    50:  {15: 270, 30: 133, 60: 65,   120: 32,   240: None},
    60:  {15: None, 30: 162, 60: 79,  120: 39,   240: None},
    80:  {15: None, 30: 217, 60: 107, 120: 52,   240: None},
    100: {15: None, 30: 273, 60: 135, 120: 66,   240: 32},
    200: {15: None, 30: None, 60: 264, 120: 132, 240: None},
    400: {15: None, 30: None, 60: None, 120: 264, 240: None},
}

# FFT sizes for different bandwidths (each must cover N_RB * 12 subcarriers)
FFT_SIZE_TABLE = {
    5: 512,
    10: 1024,
    15: 1536,
    20: 2048,
    25: 2048,
    30: 2048,
    40: 4096,
    50: 4096,
    60: 4096,
    80: 4096,
    100: 4096,
    200: 4096,
    400: 4096,
}

# Get maximum RBs for the configured bandwidth and SCS
MAX_RBS = MAX_RBS_TABLE.get(CHANNEL_BW_MHZ, {}).get(SSB_SCS_KHZ, None)
if MAX_RBS is None:
    raise ValueError(f"Invalid combination: {CHANNEL_BW_MHZ} MHz with {SSB_SCS_KHZ} kHz SCS")

# =============================================================================
# SSB Physical Parameters (3GPP TS 38.211 Section 7.4.3)
# =============================================================================

SSB_NUM_RBS = 20                          # An SSB always spans 20 RBs
SSB_NUM_SUBCARRIERS = 12 * SSB_NUM_RBS    # -> 240 subcarriers
SSB_NUM_SYMBOLS = 4                       # An SSB always spans 4 OFDM symbols

# PSS/SSS placement inside the SSB (3GPP TS 38.211 Section 7.4.3.1)
PSS_SSS_NUM_SUBCARRIERS = 127             # Length of the PSS and of the SSS
PSS_SSS_START_SC = 56                     # First PSS/SSS subcarrier within the SSB
PSS_SSS_END_SC = PSS_SSS_START_SC + PSS_SSS_NUM_SUBCARRIERS  # Exclusive bound -> 183 (SC 56..182)

# PBCH placement in SSB symbol 2: two side bands around the SSS.
PBCH_LEFT_END_SC = 48                     # Left band is SC 0..47 (exclusive bound)
PBCH_RIGHT_START_SC = SSB_NUM_SUBCARRIERS - PBCH_LEFT_END_SC  # Right band is SC 192..239
# SC 48..55 and SC 183..191 of symbol 2 stay empty (gap between PBCH and SSS).
PBCH_NUM_RES = 432                        # PBCH data REs per SSB (DMRS REs not included)
# NOTE(review): the original comment reads "DMRS on every 4th subcarrier
# (v=0,1,2,3)"; the value 3 does not obviously encode that - confirm meaning.
PBCH_DMRS_DENSITY = 3

# =============================================================================
# Time Domain Configuration for 5ms Display
# =============================================================================

# Slot timing with normal CP: one slot = 14 OFDM symbols and lasts
# 1 ms / 2^mu, where SCS = 15 kHz * 2^mu (15 kHz -> 1 ms, 30 kHz -> 0.5 ms, ...).
# The cyclic prefixes are part of that duration, so 14 * (1/SCS) would
# underestimate it. Number of slots in 5 ms = 5 ms / slot duration.

SLOT_DURATION_MS = 1.0 / (SSB_SCS_KHZ / 15)  # ms per slot (15kHz=1ms, 30kHz=0.5ms)
HALF_FRAME_MS = 5.0               # 5ms half frame
# round() guards against float truncation (e.g. 9.999... -> 9)
NUM_SLOTS_IN_HALF_FRAME = int(round(HALF_FRAME_MS / SLOT_DURATION_MS))
NUM_SYMBOLS_PER_SLOT = 14         # Normal CP
NUM_OFDM_SYMBOLS = NUM_SLOTS_IN_HALF_FRAME * NUM_SYMBOLS_PER_SLOT

# FFT size based on channel bandwidth
FFT_SIZE = FFT_SIZE_TABLE.get(CHANNEL_BW_MHZ, 2048)

# Calculate guard carriers
MAX_SUBCARRIERS = MAX_RBS * 12
BASE_GUARD_TOTAL = FFT_SIZE - MAX_SUBCARRIERS
if BASE_GUARD_TOTAL < 0:
    # The 2048 fallback above (or a mismatched table entry) cannot hold the
    # occupied band; fail loudly instead of producing negative guard bands.
    raise ValueError(
        f"FFT size {FFT_SIZE} is smaller than the {MAX_SUBCARRIERS} occupied "
        f"subcarriers ({CHANNEL_BW_MHZ} MHz, {SSB_SCS_KHZ} kHz SCS)"
    )
BASE_GUARD_LEFT = BASE_GUARD_TOTAL // 2
BASE_GUARD_RIGHT = BASE_GUARD_TOTAL - BASE_GUARD_LEFT

# =============================================================================
# SSB Position in Frequency Domain (3GPP offsetToPointA / k_SSB)
# =============================================================================

# offsetToPointA: in RB units at 15 kHz common raster
# k_SSB: subcarrier offset (0..23) at 15 kHz common raster
# COMMON_SCS_KHZ: common raster (15 kHz for FR1)
# NOTE(review): for FR2 (Cases D/E) 3GPP expresses offsetToPointA in 60 kHz
# RBs; this script keeps the 15 kHz unit for every case.
COMMON_SCS_KHZ = 15

# Center the SSB in the configured channel bandwidth.
# Target start RB at the SSB SCS: mid = floor((MAX_RBS - SSB_NUM_RBS)/2).
_center_start_rb = max(0, (MAX_RBS - SSB_NUM_RBS) // 2)
# Convert that RB start (at SSB SCS) to common 15 kHz raster subcarriers.
_center_start_sc_common = int(round(_center_start_rb * 12 * (SSB_SCS_KHZ / COMMON_SCS_KHZ)))
# Derive offsetToPointA (RB) and k_SSB (SC) on the 15 kHz raster. Point A is
# subcarrier 0 of this grid, so no extra RB offset is applied (an offset here
# would move the SSB off-center and can make its start subcarrier negative).
OFFSET_TO_POINTA_RB = _center_start_sc_common // 12
K_SSB = _center_start_sc_common % 12

# Compute actual SSB start subcarrier at the SSB SCS from those offsets.
SSB_START_SC = int((OFFSET_TO_POINTA_RB * 12 + K_SSB) * (COMMON_SCS_KHZ / SSB_SCS_KHZ))
SSB_OFFSET_RB = SSB_START_SC // 12

# Keep the SSB inside [0, MAX_SUBCARRIERS): a negative start would make NumPy
# wrap indices to the top of the grid. If the SSB is wider than the channel,
# it is pinned at subcarrier 0 and drawn partially.
if SSB_START_SC < 0 or SSB_START_SC + SSB_NUM_SUBCARRIERS > MAX_SUBCARRIERS:
    SSB_START_SC = max(0, min(SSB_START_SC, MAX_SUBCARRIERS - SSB_NUM_SUBCARRIERS))
    SSB_OFFSET_RB = SSB_START_SC // 12
    print(f"WARNING: SSB shifted to fit in channel. New start SC: {SSB_START_SC}")

# =============================================================================
# Calculate SSB Burst Positions (3GPP TS 38.213 Section 4.1)
# =============================================================================

def get_ssb_symbol_positions(ssb_case, num_ssb, l_max):
    """
    Get SSB starting symbol positions for a given case (3GPP TS 38.213 Sec. 4.1).

    Start symbols are ``first_symbol + slot_pattern_period * n`` taken from
    SSB_CASE_PARAMS. Cases A/B/C use consecutive n = 0, 1, 2, ...; the FR2
    cases skip every fifth value of n:
      - Case D: n = 0-3, 5-8, 10-13, 15-18
      - Case E: n = 0-3, 5-8
    Only complete n-groups that fit in ``l_max`` candidates are generated.

    Parameters:
    -----------
    ssb_case : str
        SSB case ('A', 'B', 'C', 'D', 'E')
    num_ssb : int
        Number of SSBs to return (first ``num_ssb`` candidates, in order)
    l_max : int
        Maximum number of SSBs (4, 8, or 64)

    Returns:
    --------
    list: Starting symbol indices for each SSB burst (empty if num_ssb <= 0)

    Raises:
    -------
    KeyError: if ``ssb_case`` is not in SSB_CASE_PARAMS
    """
    # Non-consecutive half-frame indices n for the FR2 cases (38.213 Sec. 4.1).
    fr2_n_values = {
        'D': [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12, 13, 15, 16, 17, 18],
        'E': [0, 1, 2, 3, 5, 6, 7, 8],
    }

    params = SSB_CASE_PARAMS[ssb_case]
    first_symbols = params['first_symbols']
    period = params['slot_pattern_period']

    # One group of len(first_symbols) candidates per n value
    num_groups = l_max // len(first_symbols)
    n_values = list(fr2_n_values.get(ssb_case, range(num_groups)))[:num_groups]

    # n is the outer loop so positions come out in increasing time order
    ssb_positions = [first_sym + period * n
                     for n in n_values
                     for first_sym in first_symbols]

    return ssb_positions[:max(num_ssb, 0)]

# Normalize SSB transmission bitmap
def normalize_ssb_bitmap(bitmap, l_max):
    """Return ``bitmap`` reduced to its '0'/'1' characters, cut or zero-padded to ``l_max``.

    Any other character (spaces, separators, ...) is dropped first; the result
    is then truncated to ``l_max`` characters and right-padded with '0'.
    """
    digits = "".join(filter(lambda ch: ch in ('0', '1'), bitmap))
    return digits[:l_max].ljust(l_max, '0')

def apply_ssb_bitmap(ssb_positions_all, bitmap):
    """Return the entries of ``ssb_positions_all`` whose matching bitmap bit is '1'.

    Entries and bits are paired position by position; the shorter of the two
    sequences limits the result.
    """
    keep_flags = [flag == '1' for flag in bitmap]
    return [entry for entry, keep in zip(ssb_positions_all, keep_flags) if keep]

# Candidate SSB start symbols for this case (up to L_MAX), filtered by the
# transmission bitmap. Each kept entry is an (ssb_index, start_symbol) pair.
SSB_TX_BITMAP_NORM = normalize_ssb_bitmap(SSB_TX_BITMAP, L_MAX)
SSB_POSITIONS_ALL = get_ssb_symbol_positions(SSB_CASE, L_MAX, L_MAX)
SSB_ACTIVE = apply_ssb_bitmap(list(enumerate(SSB_POSITIONS_ALL)), SSB_TX_BITMAP_NORM)
SSB_SYMBOL_INDICES = [ssb_idx for ssb_idx, _ in SSB_ACTIVE]
SSB_SYMBOL_POSITIONS = [start_sym for _, start_sym in SSB_ACTIVE]
NUM_SSB_TRANSMITTED = len(SSB_SYMBOL_POSITIONS)

# =============================================================================
# Create SSB Resource Element Masks
# =============================================================================

def create_ssb_masks(num_symbols, num_subcarriers, ssb_start_sc, ssb_symbol_positions, dmrs_v=0):
    """
    Create masks for different SSB components.

    SSB Structure (within the 4 symbols x 240 subcarriers):
    - Symbol 0: PSS (SC 56-182, 127 subcarriers)
    - Symbol 1: PBCH (all 240 SC) + PBCH DMRS (every 4th SC, offset by v)
    - Symbol 2: PBCH (SC 0-47 & 192-239) + SSS (SC 56-182) + PBCH DMRS
    - Symbol 3: PBCH (all 240 SC) + PBCH DMRS

    A single 4 x 240 template is built once and copied onto every SSB burst.
    Parts of the SSB that fall outside [0, num_subcarriers) are clipped
    (never wrapped), and bursts that start before symbol 0 or do not fit
    entirely in the grid are skipped.

    Parameters:
    -----------
    num_symbols : int
        Total number of OFDM symbols in the grid
    num_subcarriers : int
        Total number of subcarriers in the channel
    ssb_start_sc : int
        Starting subcarrier of SSB in the channel (may be out of range;
        out-of-range subcarriers are ignored)
    ssb_symbol_positions : list
        Starting symbol indices for each SSB burst
    dmrs_v : int, optional
        PBCH DMRS subcarrier offset v; DMRS sits on SSB subcarriers with
        ``k % 4 == v % 4``. Per 3GPP TS 38.211 Sec. 7.4.3.1, v = N_ID_cell mod 4.
        Default 0 (DMRS on SC 0, 4, 8, ...).

    Returns:
    --------
    dict: Boolean masks of shape (num_symbols, num_subcarriers) for
          'pss', 'sss', 'pbch', 'pbch_dmrs' and 'ssb_combined'
    """
    grid_shape = (num_symbols, num_subcarriers)
    masks = {
        'pss': np.zeros(grid_shape, dtype=bool),
        'sss': np.zeros(grid_shape, dtype=bool),
        'pbch': np.zeros(grid_shape, dtype=bool),
        'pbch_dmrs': np.zeros(grid_shape, dtype=bool),
        'ssb_combined': np.zeros(grid_shape, dtype=bool),
    }

    # ---- One SSB template in SSB-local coordinates (symbol, subcarrier) ----
    k = np.arange(SSB_NUM_SUBCARRIERS)
    is_dmrs_sc = (k % 4) == (dmrs_v % 4)
    in_pss_sss_band = (k >= PSS_SSS_START_SC) & (k < PSS_SSS_END_SC)
    # Symbol 2: PBCH(0-47) | unused(48-55) | SSS(56-182) | unused(183-191) | PBCH(192-239)
    in_pbch_sides = (k < PBCH_LEFT_END_SC) | (k >= PBCH_RIGHT_START_SC)

    template_shape = (SSB_NUM_SYMBOLS, SSB_NUM_SUBCARRIERS)
    template = {name: np.zeros(template_shape, dtype=bool)
                for name in ('pss', 'sss', 'pbch', 'pbch_dmrs')}
    template['pss'][0] = in_pss_sss_band
    template['sss'][2] = in_pss_sss_band
    for sym in (1, 3):  # full-width PBCH symbols
        template['pbch_dmrs'][sym] = is_dmrs_sc
        template['pbch'][sym] = ~is_dmrs_sc
    template['pbch_dmrs'][2] = in_pbch_sides & is_dmrs_sc
    template['pbch'][2] = in_pbch_sides & ~is_dmrs_sc
    template['ssb_combined'] = (template['pss'] | template['sss']
                                | template['pbch'] | template['pbch_dmrs'])

    # ---- Clip the SSB band to the channel (never index negatively) ----
    sc_lo = max(ssb_start_sc, 0)
    sc_hi = min(ssb_start_sc + SSB_NUM_SUBCARRIERS, num_subcarriers)
    if sc_lo >= sc_hi:
        return masks  # SSB lies completely outside the channel
    template_cols = slice(sc_lo - ssb_start_sc, sc_hi - ssb_start_sc)

    # ---- Stamp the template onto every burst that fits in time ----
    for ssb_start_sym in ssb_symbol_positions:
        if ssb_start_sym < 0 or ssb_start_sym + SSB_NUM_SYMBOLS > num_symbols:
            continue  # Skip if SSB doesn't fit in the grid
        rows = slice(ssb_start_sym, ssb_start_sym + SSB_NUM_SYMBOLS)
        for name, mask in masks.items():
            # OR-in so bursts never clear REs set by another burst
            mask[rows, sc_lo:sc_hi] |= template[name][:, template_cols]

    return masks

# =============================================================================
# Generate SSB Sequences
# =============================================================================

def generate_pss_sequence(n_id_2):
    """
    Generate PSS sequence (3GPP TS 38.211 Section 7.4.2.2.1).

    d_PSS(n) = 1 - 2*x(m),  m = (n + 43*N_ID^(2)) mod 127,  0 <= n < 127
    x(i+7)   = (x(i+4) + x(i)) mod 2
    [x(6) x(5) x(4) x(3) x(2) x(1) x(0)] = [1 1 1 0 1 1 0]

    Parameters:
    -----------
    n_id_2 : int
        Physical-layer identity within the cell identity group, N_ID^(2)
        (0, 1, or 2)

    Returns:
    --------
    np.array: Complex PSS sequence of length 127 (values +1/-1)
    """
    # Initial state written in index order x(0)..x(6)
    x = [0, 1, 1, 0, 1, 1, 1]
    for i in range(127 - 7):
        x.append((x[i + 4] + x[i]) % 2)
    x = np.array(x)

    # Cyclic shift by 43 * N_ID^(2)
    m = (np.arange(127) + 43 * n_id_2) % 127
    return (1 - 2 * x[m]).astype(complex)


def generate_sss_sequence(n_id_1, n_id_2):
    """
    Generate SSS sequence (3GPP TS 38.211 Section 7.4.2.3.1).

    d_SSS(n) = [1 - 2*x0((n + m0) mod 127)] * [1 - 2*x1((n + m1) mod 127)]
    m0 = 15*floor(N_ID^(1) / 112) + 5*N_ID^(2),  m1 = N_ID^(1) mod 112
    x0(i+7) = (x0(i+4) + x0(i)) mod 2
    x1(i+7) = (x1(i+1) + x1(i)) mod 2
    [x(6) ... x(0)] = [0 0 0 0 0 0 1] for both x0 and x1

    Parameters:
    -----------
    n_id_1 : int
        Physical-layer cell identity group, N_ID^(1) (0-335)
    n_id_2 : int
        Physical-layer identity within the group, N_ID^(2) (0, 1, or 2)

    Returns:
    --------
    np.array: Complex SSS sequence of length 127 (values +1/-1)
    """
    # Both registers start with x(0) = 1 and x(1..6) = 0 (index order)
    x0 = [1, 0, 0, 0, 0, 0, 0]
    x1 = [1, 0, 0, 0, 0, 0, 0]
    for i in range(127 - 7):
        x0.append((x0[i + 4] + x0[i]) % 2)
        x1.append((x1[i + 1] + x1[i]) % 2)
    x0 = np.array(x0)
    x1 = np.array(x1)

    m0 = 15 * (n_id_1 // 112) + 5 * n_id_2
    m1 = n_id_1 % 112

    n = np.arange(127)
    d_sss = (1 - 2 * x0[(n + m0) % 127]) * (1 - 2 * x1[(n + m1) % 127])
    return d_sss.astype(complex)


def _nr_gold_sequence(c_init, length):
    """Return ``length`` bits of the 3GPP TS 38.211 Sec. 5.2.1 pseudo-random sequence c(n)."""
    nc = 1600  # N_C: output starts after discarding the first 1600 bits
    total = nc + length
    x1 = [1] + [0] * 30                           # x1(0) = 1, x1(1..30) = 0
    x2 = [(c_init >> i) & 1 for i in range(31)]   # c_init = sum x2(i) * 2^i
    for n in range(total - 31):
        x1.append(x1[n + 3] ^ x1[n])
        x2.append(x2[n + 3] ^ x2[n + 2] ^ x2[n + 1] ^ x2[n])
    return (np.array(x1[nc:total], dtype=np.int8)
            ^ np.array(x2[nc:total], dtype=np.int8))


def generate_pbch_dmrs_sequence(n_cell_id, ssb_index, l_ssb_max, n_hf=0):
    """
    Generate PBCH DMRS sequence (3GPP TS 38.211 Section 7.4.1.4.1).

    r(m) = (1 - 2c(2m))/sqrt(2) + j(1 - 2c(2m+1))/sqrt(2), with c the Gold
    sequence initialised by
        c_init = 2^11 (i_bar + 1)(floor(N_ID_cell/4) + 1) + 2^6 (i_bar + 1) + (N_ID_cell mod 4)
    where i_bar = (2 LSBs of ssb_index) + 4*n_hf for L_max = 4, and
    i_bar = 3 LSBs of ssb_index otherwise.

    The result does not depend on or modify NumPy's global random state.

    Parameters:
    -----------
    n_cell_id : int
        Physical cell ID (0-1007)
    ssb_index : int
        SSB index (0 to L_max-1)
    l_ssb_max : int
        Maximum number of SSBs (4, 8 or 64)
    n_hf : int, optional
        Half-frame number (0 or 1); only used when l_ssb_max is 4. Default 0.

    Returns:
    --------
    np.array: Complex QPSK DMRS sequence of length 180. Only the first 144
    symbols are defined by 38.211 (60 + 24 + 60 DMRS REs in SSB symbols
    1, 2, 3); the length of 180 (3 x 60) is kept for backward compatibility
    and the extra symbols simply continue the same Gold sequence.
    """
    num_dmrs_per_symbol = 60  # 240/4 = 60 DMRS REs per symbol
    num_dmrs_symbols = 3      # Symbols 1, 2, 3
    total_dmrs = num_dmrs_per_symbol * num_dmrs_symbols

    if l_ssb_max <= 4:
        i_ssb_bar = (ssb_index & 0b11) + 4 * n_hf
    else:
        i_ssb_bar = ssb_index & 0b111

    c_init = ((1 << 11) * (i_ssb_bar + 1) * (n_cell_id // 4 + 1)
              + (1 << 6) * (i_ssb_bar + 1)
              + (n_cell_id % 4))

    c = _nr_gold_sequence(c_init, 2 * total_dmrs).astype(float)
    return ((1 - 2 * c[0::2]) + 1j * (1 - 2 * c[1::2])) / np.sqrt(2)


# =============================================================================
# Cell ID Configuration
# =============================================================================

N_CELL_ID = 0                     # Physical Cell ID (0-1007)
if not 0 <= N_CELL_ID <= 1007:
    raise ValueError(f"N_CELL_ID must be in 0..1007, got {N_CELL_ID}")
N_ID_1 = N_CELL_ID // 3           # N_ID^(1): cell identity group (0-335)
N_ID_2 = N_CELL_ID % 3            # N_ID^(2): identity within the group (0-2)

# Generate PSS and SSS sequences
PSS_SEQUENCE = generate_pss_sequence(N_ID_2)
SSS_SEQUENCE = generate_sss_sequence(N_ID_1, N_ID_2)

# =============================================================================
# Print Configuration Summary
# =============================================================================

# One summary line per entry, emitted with a single print() call.
print(
    "=" * 70,
    "SSB Resource Grid Configuration",
    "=" * 70,
    f"SSB Case:                 {SSB_CASE}",
    f"SSB Subcarrier Spacing:   {SSB_SCS_KHZ} kHz",
    f"L_max:                    {L_MAX}",
    f"SSB TX Bitmap:            {SSB_TX_BITMAP_NORM}",
    f"SSBs Transmitted:         {NUM_SSB_TRANSMITTED}",
    f"SSB Symbol Positions:     {SSB_SYMBOL_POSITIONS}",
    f"SSB Indices (bitmap):     {SSB_SYMBOL_INDICES}",
    f"offsetToPointA (RB@15k):  {OFFSET_TO_POINTA_RB}",
    f"k_SSB (SC@15k):           {K_SSB}",
    "-" * 70,
    f"Channel Bandwidth:        {CHANNEL_BW_MHZ} MHz",
    f"Max RBs in Channel:       {MAX_RBS}",
    f"FFT Size:                 {FFT_SIZE}",
    f"Guard Carriers:           {BASE_GUARD_LEFT} (left), {BASE_GUARD_RIGHT} (right)",
    "-" * 70,
    f"Half Frame Duration:      {HALF_FRAME_MS} ms",
    f"Slot Duration:            {SLOT_DURATION_MS} ms",
    f"Slots in Half Frame:      {NUM_SLOTS_IN_HALF_FRAME}",
    f"Total OFDM Symbols:       {NUM_OFDM_SYMBOLS}",
    "-" * 70,
    f"SSB Size:                 {SSB_NUM_RBS} RBs x {SSB_NUM_SYMBOLS} symbols",
    f"SSB Offset (RB):          {SSB_OFFSET_RB}",
    f"SSB Start Subcarrier:     {SSB_START_SC}",
    "-" * 70,
    f"Cell ID:                  {N_CELL_ID}",
    f"N_ID^(1):                 {N_ID_1}",
    f"N_ID^(2):                 {N_ID_2}",
    "=" * 70,
    sep="\n",
)

# Build per-component SSB masks over the whole half frame
SSB_MASKS = create_ssb_masks(
    num_symbols=NUM_OFDM_SYMBOLS,
    num_subcarriers=MAX_SUBCARRIERS,
    ssb_start_sc=SSB_START_SC,
    ssb_symbol_positions=SSB_SYMBOL_POSITIONS,
)

# Count occupied REs per component (summed over all transmitted SSBs)
pss_res, sss_res, pbch_res, pbch_dmrs_res, ssb_total_res = [
    np.sum(SSB_MASKS[component])
    for component in ('pss', 'sss', 'pbch', 'pbch_dmrs', 'ssb_combined')
]

print(
    "\nSSB Resource Element Statistics:",
    f"  PSS REs:           {pss_res}",
    f"  SSS REs:           {sss_res}",
    f"  PBCH REs:          {pbch_res}",
    f"  PBCH DMRS REs:     {pbch_dmrs_res}",
    f"  Total SSB REs:     {ssb_total_res}",
    f"  REs per SSB:       {SSB_NUM_SUBCARRIERS * SSB_NUM_SYMBOLS}",
    sep="\n",
)

# =============================================================================
# Create Resource Grid Display Data
# =============================================================================

def create_resource_grid_display(num_symbols, num_subcarriers, ssb_masks):
    """
    Collapse the SSB component masks into one float grid of RE-type codes.

    Codes:
    0 = Empty/Unused
    1 = PSS
    2 = SSS
    3 = PBCH
    4 = PBCH DMRS

    Components are painted in code order, so where masks overlap the higher
    code wins (PBCH DMRS > PBCH > SSS > PSS).

    Parameters:
    -----------
    num_symbols : int
        Total OFDM symbols (grid rows)
    num_subcarriers : int
        Total subcarriers (grid columns)
    ssb_masks : dict
        Boolean masks keyed 'pss', 'sss', 'pbch', 'pbch_dmrs'

    Returns:
    --------
    np.array: (num_symbols, num_subcarriers) float grid of RE-type codes
    """
    paint_order = (('pss', 1), ('sss', 2), ('pbch', 3), ('pbch_dmrs', 4))

    display = np.zeros((num_symbols, num_subcarriers), dtype=float)
    for component, code in paint_order:
        display[ssb_masks[component]] = code
    return display

# Render the SSB masks into a single RE-type grid for plotting
RESOURCE_GRID_DISPLAY = create_resource_grid_display(
    num_symbols=NUM_OFDM_SYMBOLS,
    num_subcarriers=MAX_SUBCARRIERS,
    ssb_masks=SSB_MASKS,
)

# =============================================================================
# Tabbed GUI Visualization
# =============================================================================

class SSBResourceGridViewer:
    """
    Tabbed GUI for viewing SSB Resource Grid visualizations
    """
    def __init__(self, resource_grid_display, ssb_masks, ssb_info):
        """Store the data to display, create the Tk window and build all tabs.

        Parameters
        ----------
        resource_grid_display : np.ndarray
            RE-type grid to draw in the full-grid tab.
        ssb_masks : dict
            Per-component SSB masks.
        ssb_info : list of (int, int)
            ``(ssb_index, start_symbol)`` pairs for the transmitted SSBs.
        """
        self.rg_display = resource_grid_display
        self.ssb_masks = ssb_masks
        self.ssb_info = ssb_info
        # Split the (index, start_symbol) pairs into two parallel lists
        self.ssb_indices = [ssb_idx for ssb_idx, _start in ssb_info]
        self.ssb_positions = [start for _ssb_idx, start in ssb_info]

        # Main window
        window = tk.Tk()
        window.title(f"NR SSB Resource Grid Viewer - Case {SSB_CASE}, {CHANNEL_BW_MHZ} MHz, {SSB_SCS_KHZ} kHz SCS")
        window.geometry("1500x900")
        self.root = window

        # Tabbed container: one page per visualization
        notebook = ttk.Notebook(self.root)
        notebook.pack(fill='both', expand=True, padx=5, pady=5)
        self.notebook = notebook

        self.create_tab1_full_grid()
        self.create_tab2_ssb_structure()
        self.create_tab3_time_domain()
        self.create_tab4_constellation()
        
    def create_tab1_full_grid(self):
        """Tab 1: Full Resource Grid showing 5ms with all SSBs.

        Draws ``self.rg_display`` transposed (symbols on x, subcarriers on y)
        with a fixed 5-colour map for RE codes 0-4. It overlays slot boundaries,
        10-RB boundaries, the SSB frequency band and one outline per
        transmitted SSB, labelling at most 16 of them. It then adds slot / RB
        secondary axes and a legend. The figure is embedded, together with a
        navigation toolbar, in a new notebook tab.
        """
        tab = ttk.Frame(self.notebook)
        self.notebook.add(tab, text="Full Resource Grid (5ms)")
        
        # Create figure
        fig = Figure(figsize=(14, 8))
        ax = fig.add_subplot(1, 1, 1)
        
        # Custom colormap for SSB components
        # 0=Empty (gray), 1=PSS (red), 2=SSS (blue), 3=PBCH (green), 4=PBCH_DMRS (yellow)
        colors = ['#404040',   # 0: Empty - Dark gray
                  '#ff4444',   # 1: PSS - Red
                  '#4444ff',   # 2: SSS - Blue
                  '#44ff44',   # 3: PBCH - Green
                  '#ffff00']   # 4: PBCH DMRS - Yellow
        cmap = ListedColormap(colors)
        
        # Plot the full resource grid (transposed: time on x, frequency on y)
        ax.imshow(self.rg_display.T, aspect='auto', origin='lower',
                  cmap=cmap, interpolation='nearest', vmin=0, vmax=4)
        
        # Add slot boundary lines
        for slot in range(NUM_SLOTS_IN_HALF_FRAME + 1):
            sym = slot * NUM_SYMBOLS_PER_SLOT
            ax.axvline(x=sym - 0.5, color='white', linewidth=0.5, alpha=0.5)
        
        # Add RB boundary lines (every 10 RBs for visibility)
        for rb in range(0, MAX_RBS + 1, 10):
            ax.axhline(y=rb * 12 - 0.5, color='white', linewidth=0.3, alpha=0.3)
        
        # Highlight SSB region
        ssb_rect = Rectangle(
            (-0.5, SSB_START_SC - 0.5),
            NUM_OFDM_SYMBOLS, SSB_NUM_SUBCARRIERS,
            linewidth=2, edgecolor='cyan', facecolor='none',
            linestyle='--', label=f'SSB Region ({SSB_NUM_RBS} RBs)'
        )
        ax.add_patch(ssb_rect)

        # Annotate offsetToPointA / k_SSB (15 kHz raster), derived start/end SC/RB
        offset_info = (
            f"offsetToPointA (RB@15k): {OFFSET_TO_POINTA_RB}\n"
            f"k_SSB (SC@15k): {K_SSB}\n"
            f"SSB start SC: {SSB_START_SC}\n"
            f"SSB end SC:   {SSB_START_SC + SSB_NUM_SUBCARRIERS - 1}\n"
            f"SSB start RB: {SSB_OFFSET_RB}"
        )
        ax.text(0.01, 0.99, offset_info, transform=ax.transAxes,
                fontsize=8, va='top', ha='left',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.75, edgecolor='magenta'))
        
        # Mark each SSB burst. Labels are thinned to at most 16 by taking every
        # Nth entry of the transmitted list. Thinning by SSB index instead could
        # hide every label with a sparse bitmap (e.g. only odd indices).
        ssb_label_step = max(1, -(-len(self.ssb_indices) // 16))  # ceil(n / 16)
        for burst_no, (idx, ssb_sym) in enumerate(self.ssb_info):
            ssb_burst_rect = Rectangle(
                (ssb_sym - 0.5, SSB_START_SC - 0.5),
                SSB_NUM_SYMBOLS, SSB_NUM_SUBCARRIERS,
                linewidth=1.5, edgecolor='magenta', facecolor='none',
                alpha=0.8
            )
            ax.add_patch(ssb_burst_rect)
            if burst_no % ssb_label_step == 0:
                # The outline spans [ssb_sym - 0.5, ssb_sym + SSB_NUM_SYMBOLS - 0.5];
                # centre the label on it.
                label_x = ssb_sym + (SSB_NUM_SYMBOLS - 1) / 2
                ax.text(label_x, SSB_START_SC + SSB_NUM_SUBCARRIERS + 5,
                        f'{idx}', ha='center', va='bottom', fontsize=7, color='magenta')
        
        ax.set_xlabel('OFDM Symbol Index', fontsize=10)
        ax.set_ylabel('Subcarrier Index', fontsize=10)
        ax.set_title(f'SSB Resource Grid - Full 5ms Half Frame\n'
                     f'Case {SSB_CASE}, {NUM_SSB_TRANSMITTED} SSBs, '
                     f'{CHANNEL_BW_MHZ} MHz Channel, {SSB_SCS_KHZ} kHz SCS', fontsize=11)
        
        # Slot labels on top (use short format to avoid overlap)
        ax_top = ax.twiny()
        ax_top.set_xlim(ax.get_xlim())
        # Show every Nth slot label based on total number of slots
        slot_label_step = max(1, NUM_SLOTS_IN_HALF_FRAME // 20)  # Show ~20 labels max
        slot_positions = [(i + 0.5) * NUM_SYMBOLS_PER_SLOT for i in range(0, NUM_SLOTS_IN_HALF_FRAME, slot_label_step)]
        ax_top.set_xticks(slot_positions)
        ax_top.set_xticklabels([f's{i}' for i in range(0, NUM_SLOTS_IN_HALF_FRAME, slot_label_step)], fontsize=7)
        ax_top.set_xlabel('Slot Index', fontsize=10)
        
        # RB labels on right
        ax_right = ax.twinx()
        ax_right.set_ylim(ax.get_ylim())
        rb_step = 10
        rb_positions = [i * 12 for i in range(0, MAX_RBS, rb_step)]
        ax_right.set_yticks(rb_positions)
        ax_right.set_yticklabels([f'RB{i}' for i in range(0, MAX_RBS, rb_step)], fontsize=8)
        ax_right.set_ylabel('Resource Block', fontsize=10)
        
        # Legend
        legend_elements = [
            Patch(facecolor='#404040', edgecolor='black', label='Empty'),
            Patch(facecolor='#ff4444', edgecolor='black', label='PSS'),
            Patch(facecolor='#4444ff', edgecolor='black', label='SSS'),
            Patch(facecolor='#44ff44', edgecolor='black', label='PBCH'),
            Patch(facecolor='#ffff00', edgecolor='black', label='PBCH DMRS'),
        ]
        ax.legend(handles=legend_elements, loc='upper right', fontsize=8)
        
        fig.tight_layout()
        
        # Embed in tkinter
        canvas = FigureCanvasTkAgg(fig, master=tab)
        canvas.draw()
        canvas.get_tk_widget().pack(fill='both', expand=True)
        
        # Add toolbar
        toolbar = NavigationToolbar2Tk(canvas, tab)
        toolbar.update()
    
    def create_tab2_ssb_structure(self):
        """Add the "SSB Structure Diagram" tab.

        Renders a schematic of one SS/PBCH block (SSB_NUM_SYMBOLS OFDM symbols
        by SSB_NUM_SUBCARRIERS subcarriers), colour-coded by content (PSS,
        SSS, PBCH, PBCH DMRS, unused) following 3GPP TS 38.211 Section 7.4.3.
        It embeds the figure with a navigation toolbar in a new notebook tab.
        """
        tab = ttk.Frame(self.notebook)
        self.notebook.add(tab, text="SSB Structure Diagram")

        fig = Figure(figsize=(14, 8))
        ax = fig.add_subplot(1, 1, 1)

        # The PBCH DMRS comb is shifted by v = N_ID^cell mod 4 (TS 38.211
        # 7.4.3.1). It only starts at subcarrier 0 for cell IDs divisible by 4.
        dmrs_offset = N_CELL_ID % 4
        ssb_grid = self._build_ssb_structure_grid(dmrs_offset)

        # Colour i renders label i: unused, PSS, SSS, PBCH, PBCH DMRS
        colors = ['#404040', '#ff4444', '#4444ff', '#44ff44', '#ffff00']
        ax.imshow(ssb_grid.T, aspect='auto', origin='lower',
                  cmap=ListedColormap(colors), interpolation='nearest',
                  vmin=0, vmax=4)

        # Heavy lines between symbols; light lines at every RB (12 SCs)
        for sym in range(SSB_NUM_SYMBOLS + 1):
            ax.axvline(x=sym - 0.5, color='black', linewidth=2)
        for rb in range(SSB_NUM_RBS + 1):
            ax.axhline(y=rb * 12 - 0.5, color='black', linewidth=0.3, alpha=0.5)

        # Per-symbol captions, centred vertically within the SSB
        mid_sc = SSB_NUM_SUBCARRIERS / 2
        ax.text(0, mid_sc, 'PSS\n(127 SC)', ha='center', va='center',
                fontsize=12, fontweight='bold', color='white')
        ax.text(1, mid_sc, 'PBCH\n+ DMRS', ha='center', va='center',
                fontsize=12, fontweight='bold', color='black')
        ax.text(2, mid_sc, 'SSS (127 SC)\n+ PBCH sides', ha='center', va='center',
                fontsize=10, fontweight='bold', color='white')
        ax.text(3, mid_sc, 'PBCH\n+ DMRS', ha='center', va='center',
                fontsize=12, fontweight='bold', color='black')

        # Dashed lines on the outer edges of the PSS/SSS band. With no explicit
        # extent, imshow centres subcarrier k on y = k, so cell edges lie at
        # k - 0.5. PSS_SSS_END_SC is exclusive.
        ax.axhline(y=PSS_SSS_START_SC - 0.5, color='white', linewidth=2, linestyle='--')
        ax.axhline(y=PSS_SSS_END_SC - 0.5, color='white', linewidth=2, linestyle='--')

        ax.set_xlabel('SSB Symbol Index', fontsize=12)
        ax.set_ylabel('Subcarrier Index (within SSB)', fontsize=12)
        ax.set_title('SSB (Synchronization Signal Block) Structure\n'
                     '3GPP TS 38.211 Section 7.4.3', fontsize=14)
        ax.set_xticks(range(SSB_NUM_SYMBOLS))
        ax.set_xticklabels(['Symbol 0\n(PSS)', 'Symbol 1\n(PBCH)',
                            'Symbol 2\n(SSS+PBCH)', 'Symbol 3\n(PBCH)'])

        legend_elements = [
            Patch(facecolor='#ff4444', edgecolor='black', label='PSS (Primary Sync Signal)'),
            Patch(facecolor='#4444ff', edgecolor='black', label='SSS (Secondary Sync Signal)'),
            Patch(facecolor='#44ff44', edgecolor='black', label='PBCH (Broadcast Channel)'),
            Patch(facecolor='#ffff00', edgecolor='black', label='PBCH DMRS'),
            Patch(facecolor='#404040', edgecolor='black', label='Unused'),
        ]
        ax.legend(handles=legend_elements, loc='upper right', fontsize=10)

        # Parameter summary box in the upper-left corner of the axes
        info_text = (f"SSB Parameters:\n"
                     f"• Case: {SSB_CASE}\n"
                     f"• SCS: {SSB_SCS_KHZ} kHz\n"
                     f"• Size: {SSB_NUM_RBS} RBs × {SSB_NUM_SYMBOLS} symbols\n"
                     f"• PSS/SSS: 127 subcarriers (centered)\n"
                     f"• PBCH DMRS: Every 4th subcarrier (v = {dmrs_offset})\n"
                     f"• Cell ID: {N_CELL_ID}")
        ax.text(0.02, 0.98, info_text, transform=ax.transAxes, fontsize=9,
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        fig.tight_layout()

        canvas = FigureCanvasTkAgg(fig, master=tab)
        canvas.draw()
        canvas.get_tk_widget().pack(fill='both', expand=True)

        toolbar = NavigationToolbar2Tk(canvas, tab)
        toolbar.update()

    @staticmethod
    def _build_ssb_structure_grid(dmrs_offset=0):
        """Return the per-RE content labels of one SSB as a float array.

        Shape is (SSB_NUM_SYMBOLS, SSB_NUM_SUBCARRIERS). Labels are
        0 = unused, 1 = PSS, 2 = SSS, 3 = PBCH data, 4 = PBCH DMRS.

        Args:
            dmrs_offset: PBCH DMRS frequency offset v. DMRS occupies the
                subcarriers k with k % 4 == v % 4.
        """
        grid = np.zeros((SSB_NUM_SYMBOLS, SSB_NUM_SUBCARRIERS))
        subcarriers = np.arange(SSB_NUM_SUBCARRIERS)
        # A full-width PBCH symbol: DMRS on every 4th subcarrier, data elsewhere
        pbch_row = np.where(subcarriers % 4 == dmrs_offset % 4, 4, 3)

        # Symbol 0: PSS in the centre band, everything else unused
        grid[0, PSS_SSS_START_SC:PSS_SSS_END_SC] = 1
        # Symbols 1 and 3: PBCH + DMRS across the whole SSB
        grid[1] = pbch_row
        grid[3] = pbch_row
        # Symbol 2: PBCH + DMRS on the two side bands, SSS in the centre.
        # The guard subcarriers between them stay unused (0).
        grid[2, :PBCH_LEFT_END_SC] = pbch_row[:PBCH_LEFT_END_SC]
        grid[2, PBCH_RIGHT_START_SC:] = pbch_row[PBCH_RIGHT_START_SC:]
        grid[2, PSS_SSS_START_SC:PSS_SSS_END_SC] = 2  # SSS takes precedence
        return grid
    
    def create_tab3_time_domain(self):
        """Add the "SSB Time Pattern" tab.

        The upper panel shows one bar per OFDM symbol of the half frame,
        highlighted where an SSB occupies the symbol. The lower panel shows the
        same bursts as blocks on a millisecond time axis with slot boundaries.
        """
        page = ttk.Frame(self.notebook)
        self.notebook.add(page, text="SSB Time Pattern")

        figure = Figure(figsize=(14, 8))

        # Label thinning shared by both panels: about 16 SSB labels and about
        # 20 slot labels at most.
        ssb_every = max(1, len(self.ssb_indices) // 16)
        slot_every = max(1, NUM_SLOTS_IN_HALF_FRAME // 20)

        occupancy_ax = figure.add_subplot(2, 1, 1)
        self._time_tab_draw_occupancy(occupancy_ax, ssb_every, slot_every)

        timeline_ax = figure.add_subplot(2, 1, 2)
        self._time_tab_draw_timeline(timeline_ax, ssb_every, slot_every)

        figure.tight_layout()

        canvas = FigureCanvasTkAgg(figure, master=page)
        canvas.draw()
        canvas.get_tk_widget().pack(fill='both', expand=True)

        toolbar = NavigationToolbar2Tk(canvas, page)
        toolbar.update()

    def _time_tab_draw_occupancy(self, ax, ssb_every, slot_every):
        """Draw the per-OFDM-symbol SSB occupancy bar chart onto *ax*."""
        occupied = np.zeros(NUM_OFDM_SYMBOLS)
        for _, first_sym in self.ssb_info:
            for sym in range(first_sym, first_sym + SSB_NUM_SYMBOLS):
                if sym < NUM_OFDM_SYMBOLS:  # clip bursts running past the half frame
                    occupied[sym] = 1

        bar_colors = ['#ff6600' if flag else '#404040' for flag in occupied]
        ax.bar(range(NUM_OFDM_SYMBOLS), np.ones(NUM_OFDM_SYMBOLS),
               color=bar_colors, edgecolor='black', linewidth=0.5)

        # Solid blue line on every slot edge (between symbol columns)
        for slot in range(NUM_SLOTS_IN_HALF_FRAME + 1):
            ax.axvline(x=slot * NUM_SYMBOLS_PER_SLOT - 0.5, color='blue',
                       linewidth=1.5, linestyle='-')

        # SSB index above each labelled burst
        for ssb_idx, first_sym in self.ssb_info:
            if ssb_idx % ssb_every:
                continue
            ax.annotate(f'{ssb_idx}', xy=(first_sym + 1.5, 1.05), ha='center',
                        fontsize=7, fontweight='bold', color='#ff6600')

        ax.set_xlabel('OFDM Symbol Index', fontsize=10)
        ax.set_ylabel('SSB Presence', fontsize=10)

        # Abbreviate the start-symbol list in the title when there are many SSBs
        starts = [pos for _, pos in self.ssb_info]
        if len(starts) > 8:
            pos_str = f'[{starts[0]}, {starts[1]}, ..., {starts[-1]}]'
        else:
            pos_str = str(starts)
        ax.set_title(f'SSB Burst Pattern in 5ms Half Frame\n'
                     f'Case {SSB_CASE}: {NUM_SSB_TRANSMITTED} SSBs at symbols {pos_str}',
                     fontsize=11)
        ax.set_xlim(-0.5, NUM_OFDM_SYMBOLS - 0.5)
        ax.set_ylim(0, 1.2)
        ax.set_yticks([])

        # Slot numbers on a secondary top axis sharing the symbol x-range
        top_ax = ax.twiny()
        top_ax.set_xlim(ax.get_xlim())
        labelled_slots = range(0, NUM_SLOTS_IN_HALF_FRAME, slot_every)
        top_ax.set_xticks([(s + 0.5) * NUM_SYMBOLS_PER_SLOT for s in labelled_slots])
        top_ax.set_xticklabels([f's{s}' for s in labelled_slots], fontsize=7)

        ax.legend(handles=[
            Patch(facecolor='#ff6600', edgecolor='black', label='SSB'),
            Patch(facecolor='#404040', edgecolor='black', label='Empty'),
        ], loc='upper right', fontsize=9)

    def _time_tab_draw_timeline(self, ax, ssb_every, slot_every):
        """Draw the SSB bursts as blocks on a millisecond time axis onto *ax*."""
        ax.axhline(y=0.5, color='black', linewidth=2)

        # Every symbol is treated as slot / symbols-per-slot long
        sym_ms = SLOT_DURATION_MS / NUM_SYMBOLS_PER_SLOT
        burst_ms = SSB_NUM_SYMBOLS * sym_ms

        for ssb_idx, first_sym in self.ssb_info:
            t0 = first_sym * sym_ms
            ax.add_patch(Rectangle((t0, 0.2), burst_ms, 0.6,
                                   facecolor='#ff6600', edgecolor='black', linewidth=1))
            if ssb_idx % ssb_every == 0:
                ax.text(t0 + burst_ms / 2, 0.5, f'{ssb_idx}',
                        ha='center', va='center', fontsize=8, fontweight='bold', color='white')

        for slot in range(NUM_SLOTS_IN_HALF_FRAME + 1):
            slot_start = slot * SLOT_DURATION_MS
            ax.axvline(x=slot_start, color='blue', linewidth=1, linestyle='--', alpha=0.7)
            is_labelled = slot < NUM_SLOTS_IN_HALF_FRAME and slot % slot_every == 0
            if is_labelled:
                ax.text(slot_start + SLOT_DURATION_MS / 2, 0.9, f's{slot}',
                        ha='center', va='center', fontsize=7, color='blue')

        ax.set_xlim(-0.1, HALF_FRAME_MS + 0.1)
        ax.set_ylim(0, 1.1)
        ax.set_xlabel('Time (ms)', fontsize=10)
        ax.set_title(f'SSB Timing in 5ms Half Frame\n'
                     f'Slot Duration: {SLOT_DURATION_MS:.3f} ms, Symbol Duration: {sym_ms*1000:.1f} µs',
                     fontsize=11)
        ax.set_yticks([])

        pattern_period = SSB_CASE_PARAMS[SSB_CASE]['slot_pattern_period']
        info_text = (f"SSB Case {SSB_CASE} Pattern:\n"
                     f"• First symbols: {SSB_FIRST_SYMBOLS}\n"
                     f"• Slot pattern period: {pattern_period} symbols\n"
                     f"• L_max: {L_MAX}\n"
                     f"• SSBs transmitted: {NUM_SSB_TRANSMITTED}")
        ax.text(0.02, 0.85, info_text, transform=ax.transAxes, fontsize=9,
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.9))
    
    def create_tab4_constellation(self):
        """Add the "Constellation" tab with I/Q scatter plots of the SSB signals.

        The layout is a 2 x 3 grid. PSS, SSS, PBCH and PBCH DMRS each get their
        own panel, the fifth panel overlays all four, and the sixth is a text
        summary of each signal.
        """
        tab = ttk.Frame(self.notebook)
        self.notebook.add(tab, text="Constellation")

        fig = Figure(figsize=(14, 9))

        # PSS / SSS: the precomputed 127-symbol BPSK (real +1/-1) sequences
        pss_symbols = PSS_SEQUENCE
        sss_symbols = SSS_SEQUENCE

        # PBCH: a random QPSK stand-in for the encoded payload. Per SSB the PBCH
        # covers 240 SCs on symbols 1 and 3 plus 2 x 48 SCs on symbol 2
        # (576 REs). One in four of those carries DMRS, leaving 432 data REs.
        num_pbch_data_res = 432
        qpsk_constellation = np.array([1+1j, 1-1j, -1+1j, -1-1j]) / np.sqrt(2)
        # Use a private generator: np.random.seed() would reseed NumPy's global
        # RNG for the whole process as a side effect of drawing a plot. A
        # RandomState seeded with the cell ID keeps the panel reproducible.
        rng = np.random.RandomState(N_CELL_ID)
        pbch_symbols = qpsk_constellation[rng.randint(0, 4, num_pbch_data_res)]

        # PBCH DMRS (QPSK). NOTE(review): the second argument presumably selects
        # the SSB index (0 = first SSB); confirm against generate_pbch_dmrs_sequence.
        dmrs_symbols = generate_pbch_dmrs_sequence(N_CELL_ID, 0, L_MAX)

        # Panels 1-4: one constellation per SSB component
        self._draw_constellation_panel(
            fig.add_subplot(2, 3, 1), pss_symbols, '#ff4444',
            'PSS Constellation (BPSK)', limit=1.5, marker_size=50,
            alpha=0.7, edge_width=0.5)
        self._draw_constellation_panel(
            fig.add_subplot(2, 3, 2), sss_symbols, '#4444ff',
            'SSS Constellation (BPSK)', limit=1.5, marker_size=50,
            alpha=0.7, edge_width=0.5)
        self._draw_constellation_panel(
            fig.add_subplot(2, 3, 3), pbch_symbols, '#44ff44',
            'PBCH Constellation (QPSK)', limit=1.2, marker_size=30,
            alpha=0.6, edge_width=0.3)
        self._draw_constellation_panel(
            fig.add_subplot(2, 3, 4), dmrs_symbols, '#ffff00',
            'PBCH DMRS Constellation (QPSK)', limit=1.2, marker_size=30,
            alpha=0.7, edge_width=0.3)

        # Panel 5: all components overlaid, distinguished by colour
        ax5 = fig.add_subplot(2, 3, 5)
        ax5.scatter(np.real(pss_symbols), np.imag(pss_symbols),
                    c='#ff4444', s=40, alpha=0.7, label='PSS', edgecolors='black', linewidth=0.3)
        ax5.scatter(np.real(sss_symbols), np.imag(sss_symbols),
                    c='#4444ff', s=40, alpha=0.7, label='SSS', edgecolors='black', linewidth=0.3)
        ax5.scatter(np.real(pbch_symbols), np.imag(pbch_symbols),
                    c='#44ff44', s=20, alpha=0.5, label='PBCH', edgecolors='none')
        ax5.scatter(np.real(dmrs_symbols), np.imag(dmrs_symbols),
                    c='#ffff00', s=20, alpha=0.6, label='DMRS', edgecolors='black', linewidth=0.2)
        ax5.axhline(y=0, color='gray', linewidth=0.5, linestyle='--')
        ax5.axvline(x=0, color='gray', linewidth=0.5, linestyle='--')
        ax5.set_xlim([-1.5, 1.5])
        ax5.set_ylim([-1.5, 1.5])
        ax5.set_aspect('equal')
        ax5.set_xlabel('In-phase (I)', fontsize=9)
        ax5.set_ylabel('Quadrature (Q)', fontsize=9)
        ax5.set_title('Combined SSB Constellation\nAll components', fontsize=10)
        ax5.legend(loc='upper right', fontsize=8)
        ax5.grid(True, alpha=0.3)

        # Panel 6: text-only summary styled as a beige card
        ax6 = fig.add_subplot(2, 3, 6)
        ax6.set_xlim([0, 1])
        ax6.set_ylim([0, 1])
        ax6.set_xticks([])
        ax6.set_yticks([])
        ax6.set_facecolor('#f5f5dc')  # Beige background
        for spine in ax6.spines.values():
            spine.set_edgecolor('#8b7355')
            spine.set_linewidth(2)
        ax6.set_title('SSB Signal Characteristics', fontsize=11, fontweight='bold', pad=10)

        # (text, colour) rows. Headings have no leading indent and are drawn bold.
        # Empty strings act as vertical spacers.
        info_lines = [
            ('PSS (Primary Sync Signal)', '#ff4444'),
            ('  Modulation: BPSK (±1)', 'black'),
            ('  Length: 127 symbols', 'black'),
            ('  m-sequence based', 'black'),
            (f'  N_ID^(2) = {N_ID_2}', 'black'),
            ('', 'black'),
            ('SSS (Secondary Sync Signal)', '#4444ff'),
            ('  Modulation: BPSK (±1)', 'black'),
            ('  Length: 127 symbols', 'black'),
            ('  Gold sequence based', 'black'),
            (f'  N_ID^(1) = {N_ID_1}', 'black'),
            ('', 'black'),
            ('PBCH (Broadcast Channel)', '#228b22'),
            ('  Modulation: QPSK', 'black'),
            (f'  REs per SSB: {num_pbch_data_res}', 'black'),
            ('  Carries MIB', 'black'),
            ('', 'black'),
            ('PBCH DMRS', '#b8860b'),
            ('  Modulation: QPSK', 'black'),
            ('  Every 4th subcarrier', 'black'),
            ('  Gold sequence based', 'black'),
            ('', 'black'),
            (f'Cell ID: {N_CELL_ID}', '#800080'),
        ]

        y_pos = 0.92
        for text, color in info_lines:
            if text:
                fontweight = 'bold' if not text.startswith('  ') else 'normal'
                ax6.text(0.08, y_pos, text, transform=ax6.transAxes, fontsize=9,
                         color=color, fontweight=fontweight, fontfamily='monospace')
            y_pos -= 0.038

        fig.tight_layout()

        # Embed in tkinter
        canvas = FigureCanvasTkAgg(fig, master=tab)
        canvas.draw()
        canvas.get_tk_widget().pack(fill='both', expand=True)

        toolbar = NavigationToolbar2Tk(canvas, tab)
        toolbar.update()

    @staticmethod
    def _draw_constellation_panel(ax, symbols, color, title, limit,
                                  marker_size, alpha, edge_width):
        """Scatter *symbols* on *ax* as a labelled I/Q constellation panel.

        Args:
            ax: Matplotlib axes to draw on.
            symbols: Sequence of (possibly real-valued) constellation points.
            color: Marker face colour.
            title: First title line. The point count is appended on a second line.
            limit: Half-width of the square, origin-centred view.
            marker_size: Scatter marker size (``s``).
            alpha: Marker opacity.
            edge_width: Width of the black marker outline.
        """
        ax.scatter(np.real(symbols), np.imag(symbols),
                   c=color, s=marker_size, alpha=alpha,
                   edgecolors='black', linewidth=edge_width)
        ax.axhline(y=0, color='gray', linewidth=0.5, linestyle='--')
        ax.axvline(x=0, color='gray', linewidth=0.5, linestyle='--')
        ax.set_xlim([-limit, limit])
        ax.set_ylim([-limit, limit])
        ax.set_aspect('equal')
        ax.set_xlabel('In-phase (I)', fontsize=9)
        ax.set_ylabel('Quadrature (Q)', fontsize=9)
        ax.set_title(f'{title}\n{len(symbols)} symbols', fontsize=10)
        ax.grid(True, alpha=0.3)
    
    def run(self):
        """Hand control to Tk by entering the root window's event loop."""
        root_window = self.root
        root_window.mainloop()


# =============================================================================
# Print Configuration Summary
# =============================================================================

def _print_configuration_summary():
    """Write the active SSB and resource-grid configuration to stdout as a framed table."""
    heavy_rule = "=" * 70
    light_rule = "-" * 70
    report = [
        "\n" + heavy_rule,
        "SSB Configuration Summary",
        heavy_rule,
        f"SSB Case:                      {SSB_CASE}",
        f"Subcarrier Spacing:            {SSB_SCS_KHZ} kHz",
        f"L_max:                         {L_MAX}",
        f"SSB TX Bitmap:                 {SSB_TX_BITMAP_NORM}",
        f"SSBs Transmitted:              {NUM_SSB_TRANSMITTED}",
        f"SSB Starting Symbols:          {SSB_SYMBOL_POSITIONS}",
        f"SSB Indices (bitmap order):    {SSB_SYMBOL_INDICES}",
        f"offsetToPointA (RB@15k):       {OFFSET_TO_POINTA_RB}",
        f"k_SSB (SC@15k):                {K_SSB}",
        light_rule,
        "Time Domain:",
        f"  Half Frame Duration:         {HALF_FRAME_MS} ms",
        f"  Slots in Half Frame:         {NUM_SLOTS_IN_HALF_FRAME}",
        f"  Symbols per Slot:            {NUM_SYMBOLS_PER_SLOT}",
        f"  Total Symbols (5ms):         {NUM_OFDM_SYMBOLS}",
        light_rule,
        "Frequency Domain:",
        f"  Channel Bandwidth:           {CHANNEL_BW_MHZ} MHz",
        f"  Max RBs in Channel:          {MAX_RBS}",
        f"  SSB Size:                    {SSB_NUM_RBS} RBs ({SSB_NUM_SUBCARRIERS} subcarriers)",
        f"  SSB Start RB:                {SSB_OFFSET_RB}",
        f"  SSB Start Subcarrier:        {SSB_START_SC}",
        light_rule,
        "SSB Components (per burst):",
        f"  PSS:                         {PSS_SSS_NUM_SUBCARRIERS} subcarriers (Symbol 0)",
        f"  SSS:                         {PSS_SSS_NUM_SUBCARRIERS} subcarriers (Symbol 2)",
        "  PBCH:                        3 symbols with DMRS (every 4th SC)",
        heavy_rule,
    ]
    # One print of the newline-joined rows emits exactly what one print per row would.
    print("\n".join(report))


_print_configuration_summary()

# =============================================================================
# Launch Viewer
# =============================================================================

_BANNER_RULE = "=" * 70

print(f"\n{_BANNER_RULE}\nLaunching SSB Resource Grid Viewer...\n{_BANNER_RULE}")

viewer = SSBResourceGridViewer(RESOURCE_GRID_DISPLAY, SSB_MASKS, SSB_ACTIVE)
viewer.run()  # blocks in the Tk event loop until the window is closed

print(f"\n{_BANNER_RULE}\nSSB Resource Grid Visualization Complete!\n{_BANNER_RULE}")

