Files
CourseWork/core/dct.py

186 lines
6.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Модуль DCT стеганографии с квантованием (JPEG-style)
Оптимизирован для длинных сообщений
"""
import numpy as np
from PIL import Image
from scipy.fftpack import dct, idct
from .utils import text_to_bits, bits_to_text
BLOCK_SIZE = 8
# Более мягкая таблица квантования (качество ~70-80%)
QUANT_TABLE = np.array([
[8, 6, 5, 8, 12, 20, 26, 31],
[6, 6, 7, 10, 13, 29, 30, 28],
[7, 7, 8, 12, 20, 29, 35, 28],
[7, 9, 11, 15, 26, 44, 40, 31],
[9, 11, 19, 28, 34, 55, 52, 39],
[12, 18, 28, 32, 41, 52, 57, 46],
[25, 32, 39, 44, 52, 61, 60, 51],
[36, 46, 48, 49, 56, 50, 52, 50]
])
def _get_zigzag_order() -> list:
"""Zigzag порядок для матрицы 8x8."""
zigzag = []
for s in range(15):
if s % 2 == 0:
for i in range(min(s, 7), max(0, s - 7) - 1, -1):
j = s - i
if 0 <= i < 8 and 0 <= j < 8:
zigzag.append(i * 8 + j)
else:
for i in range(max(0, s - 7), min(s, 7) + 1):
j = s - i
if 0 <= i < 8 and 0 <= j < 8:
zigzag.append(i * 8 + j)
return zigzag
ZIGZAG_ORDER = _get_zigzag_order()
# Используем больше коэффициентов (1-50)
MID_FREQ_INDICES = ZIGZAG_ORDER[1:50]
def _dct_quantize(block: np.ndarray) -> np.ndarray:
"""DCT + квантование."""
dct_block = dct(dct(block, axis=0, norm='ortho'), axis=1, norm='ortho')
quantized = np.round(dct_block / QUANT_TABLE)
return quantized.astype(np.int32)
def _idct_dequantize(quantized: np.ndarray) -> np.ndarray:
"""Обратное квантование + IDCT."""
dct_block = quantized * QUANT_TABLE
block = idct(idct(dct_block, axis=0, norm='ortho'), axis=1, norm='ortho')
block = np.round(block)
return np.clip(block, 0, 255).astype(np.uint8)
def _extract_bits_from_block(block_quant: np.ndarray, max_bits: int) -> list:
"""Извлекает биты из квантованных коэффициентов."""
extracted = []
coeff_flat = block_quant.flatten()
for idx in MID_FREQ_INDICES:
if len(extracted) >= max_bits:
break
if idx >= len(coeff_flat):
break
bit = int(coeff_flat[idx]) & 1
extracted.append(bit)
return extracted
def _embed_bits_in_block(block_quant: np.ndarray, bits: list, bit_index: int) -> tuple:
"""Внедряет биты в квантованные коэффициенты."""
modified = block_quant.copy()
current_idx = bit_index
coeff_flat = modified.flatten()
for idx in MID_FREQ_INDICES:
if current_idx >= len(bits):
break
if idx >= len(coeff_flat):
break
new_val = (int(coeff_flat[idx]) & 0xFE) | bits[current_idx]
coeff_flat[idx] = new_val
current_idx += 1
return coeff_flat.reshape(BLOCK_SIZE, BLOCK_SIZE), current_idx
def encode_dct(image_path: str, message: str, output_path: str) -> bool:
"""Скрывает сообщение в изображении методом DCT."""
img = Image.open(image_path).convert('RGB')
pixels = np.array(img, dtype=np.float64)
height, width, channels = pixels.shape
msg_bytes = message.encode('utf-8')
msg_length = len(msg_bytes)
length_bits = format(msg_length, '032b')
message_bits = text_to_bits(message)
all_bits = length_bits + message_bits
bit_list = [int(b) for b in all_bits]
total_bits = len(bit_list)
blocks_per_row = width // BLOCK_SIZE
blocks_per_col = height // BLOCK_SIZE
max_bits = blocks_per_row * blocks_per_col * len(MID_FREQ_INDICES) * 3
if total_bits > max_bits:
print(f"Ошибка: сообщение слишком большое.")
print(f"Доступно: {max_bits}, требуется: {total_bits}")
return False
modified_pixels = pixels.copy()
bit_index = 0
for i in range(0, height - BLOCK_SIZE + 1, BLOCK_SIZE):
for j in range(0, width - BLOCK_SIZE + 1, BLOCK_SIZE):
if bit_index >= total_bits:
break
for c in range(3):
if bit_index >= total_bits:
break
block = pixels[i:i+BLOCK_SIZE, j:j+BLOCK_SIZE, c]
quant_block = _dct_quantize(block)
quant_block, bit_index = _embed_bits_in_block(quant_block, bit_list, bit_index)
new_channel = _idct_dequantize(quant_block)
modified_pixels[i:i+BLOCK_SIZE, j:j+BLOCK_SIZE, c] = new_channel
if bit_index >= total_bits:
break
result_img = Image.fromarray(modified_pixels.astype(np.uint8), mode='RGB')
result_img.save(output_path)
print(f"Успешно! Спрятано {total_bits} бит ({total_bits // 8} байт)")
return True
def decode_dct(image_path: str) -> str:
"""Извлекает сообщение из изображения методом DCT."""
img = Image.open(image_path).convert('RGB')
pixels = np.array(img, dtype=np.float64)
height, width, channels = pixels.shape
all_bits = []
for i in range(0, height - BLOCK_SIZE + 1, BLOCK_SIZE):
for j in range(0, width - BLOCK_SIZE + 1, BLOCK_SIZE):
for c in range(3):
channel_block = pixels[i:i+BLOCK_SIZE, j:j+BLOCK_SIZE, c]
quant_block = _dct_quantize(channel_block)
bits_from_block = _extract_bits_from_block(quant_block, len(MID_FREQ_INDICES))
all_bits.extend(bits_from_block)
if len(all_bits) < 32:
return ""
length_bits = all_bits[:32]
length_str = ''.join(str(b) for b in length_bits)
msg_length = int(length_str, 2)
if msg_length > (len(all_bits) - 32) // 8:
return ""
message_bits = all_bits[32:32 + msg_length * 8]
remainder = len(message_bits) % 8
if remainder != 0:
message_bits.extend([0] * (8 - remainder))
if len(message_bits) == 0:
return ""
bits_string = ''.join(str(b) for b in message_bits)
try:
return bits_to_text(bits_string)
except Exception as e:
print(f"Ошибка при декодировании: {e}")
return ""