use hound::{WavSpec, WavWriter};
use log::{debug, info, warn};
use mutter::{Model, ModelType};
use std::i16;
use symphonia::core::audio::SampleBuffer;
use symphonia::core::codecs::DecoderOptions;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::{formats::FormatOptions, io::MediaSourceStream, probe::Hint};
/// End-to-end check of [`whisper_rs`] against sample media files.
///
/// Expects the listed files to exist under `./data/cache/`; the transcription
/// step downloads the requested model via `Model::download`.
#[tokio::test]
async fn test_whisper() {
    let test_media_paths = [
        "./data/cache/message-184888.mp4",
        "./data/cache/message-186661.mp4",
    ];
    for media_path in test_media_paths {
        let output = whisper_rs(media_path, "baseEn").await;
        println!("Output for {}: {}", media_path, output);
        assert!(output.len() > 1);
        assert!(!output.contains("unable to transcribe audio"));
        // `whisper_rs` never returns an `Err`: it reports conversion and
        // transcription failures as strings prefixed with "🚫". Without this
        // check a failed run would still satisfy the assertions above.
        assert!(
            !output.starts_with("🚫"),
            "whisper_rs reported a failure for {}: {}",
            media_path,
            output
        );
    }
}
/// Converts the media file at `media_path` to WAV and transcribes it with the
/// Whisper model named by `whisper_model`.
///
/// Always returns a `String`: the transcription text on success, or a
/// human-readable message starting with "🚫" when conversion or transcription
/// fails.
pub async fn whisper_rs(media_path: &str, whisper_model: &str) -> String {
    // Guard clause: bail out with the conversion error message right away.
    let wav_path = match convert_audio(media_path) {
        Ok(path) => path,
        Err(conversion_err) => {
            return format!("🚫 Error converting audio: {}", conversion_err);
        }
    };
    info!("✅ Audio conversion completed for: {}", media_path);

    match transcribe_audio(&wav_path, whisper_model).await {
        Ok(text) => text,
        Err(transcription_err) => format!("🚫 Transcription error: {:?}", transcription_err),
    }
}
/// Decodes the AAC audio track of `media_path` into a 16-bit integer PCM WAV
/// file and returns the path of that file.
///
/// The output path is `media_path` with `.wav` appended. If something already
/// exists at that path it is treated as a previous conversion and reused
/// without decoding again.
///
/// # Errors
///
/// Returns an error if the input cannot be opened or probed, if it has no AAC
/// track, if no decoder can be created for that track, or if writing or
/// finalizing the WAV file fails. On a write/finalize failure the partially
/// written output file is removed (best effort) so that it is not mistaken
/// for a finished conversion on the next call.
fn convert_audio(media_path: &str) -> Result<String, Box<dyn std::error::Error>> {
    // Samples are always written as `i16` (see `process_audio_packets`), so the
    // WAV header must declare 16 bits regardless of the source's coded depth.
    const WAV_BITS_PER_SAMPLE: u16 = 16;

    let output_path = format!("{}.wav", media_path);
    if std::path::Path::new(&output_path).exists() {
        info!("✅ Conversion exists. Loading: {}", output_path);
        return Ok(output_path);
    }

    info!("🎵 Starting audio conversion from: {}", media_path);
    let src = std::fs::File::open(media_path)?;
    let mss = MediaSourceStream::new(Box::new(src), Default::default());
    let mut hint = Hint::new();
    hint.with_extension("mp4");
    let format_options = FormatOptions::default();
    let metadata_options = MetadataOptions::default();
    let probed = symphonia::default::get_probe().format(
        &hint,
        mss,
        &format_options,
        &metadata_options,
    )?;
    let mut format = probed.format;

    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec == symphonia::core::codecs::CODEC_TYPE_AAC)
        .ok_or("No AAC track found")?;
    let (channels, sample_rate, _coded_bits) = get_audio_format(track);
    let track_id = track.id;
    let mut decoder = create_decoder(track)?;

    let wav_spec = WavSpec {
        channels,
        sample_rate,
        bits_per_sample: WAV_BITS_PER_SAMPLE,
        sample_format: hound::SampleFormat::Int,
    };
    let mut wav_writer = WavWriter::create(output_path.as_str(), wav_spec)?;

    let decoded = process_audio_packets(&mut format, &mut decoder, &mut wav_writer, track_id);
    let written: Result<(), Box<dyn std::error::Error>> = match decoded {
        // Finalize explicitly: dropping the writer would also update the
        // header, but any error doing so would be silently discarded.
        Ok(()) => wav_writer.finalize().map_err(|e| e.into()),
        Err(e) => {
            // Close the file before attempting to delete it below.
            drop(wav_writer);
            Err(e)
        }
    };
    if let Err(e) = written {
        // Best-effort cleanup; the original error is what the caller needs.
        let _ = std::fs::remove_file(&output_path);
        return Err(e);
    }

    info!("✅ Conversion complete. Output saved to: {}", output_path);
    Ok(output_path)
}
/// Reads the channel count, sample rate and coded bit depth from a track's
/// codec parameters.
///
/// Returns `(channels, sample_rate, bits_per_sample)`. Any value the track
/// does not provide falls back to 2 channels, 48000 Hz and 16 bits
/// respectively.
fn get_audio_format(track: &symphonia::core::formats::Track) -> (u16, u32, u16) {
    let params = &track.codec_params;

    let channel_count = match params.channels {
        Some(layout) => layout.count(),
        None => 2,
    };
    let rate = match params.sample_rate {
        Some(hz) => hz,
        None => 48000,
    };
    let coded_bits = match params.bits_per_coded_sample {
        Some(bits) => bits,
        None => 16,
    };

    let channels = channel_count as u16;
    let sample_rate = rate as u32;
    let bits_per_sample = coded_bits as u16;
    info!("📈 Input audio: {} channels, {} Hz", channels, sample_rate);
    (channels, sample_rate, bits_per_sample)
}
fn create_decoder(
track: &symphonia::core::formats::Track,
) -> Result<Box<dyn symphonia::core::codecs::Decoder>, String> {
let dec_opts: DecoderOptions = Default::default();
symphonia::default::get_codecs()
.make(&track.codec_params, &dec_opts)
.map_err(|e| format!("Unsupported codec: {}", e))
}
fn process_audio_packets(
format: &mut Box<dyn symphonia::core::formats::FormatReader>,
decoder: &mut Box<dyn symphonia::core::codecs::Decoder>,
wav_writer: &mut WavWriter<std::io::BufWriter<std::fs::File>>,
track_id: u32,
) -> Result<(), Box<dyn std::error::Error>> {
let mut packet_count = 0;
let mut total_samples = 0;
while let Ok(packet) = format.next_packet() {
packet_count += 1;
debug!(
"📦 Processing packet {}, size: {} bytes",
packet_count,
packet.data.len()
);
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet) {
Ok(audio_buf) => {
let mut sample_buf =
SampleBuffer::<i16>::new(audio_buf.capacity() as u64, *audio_buf.spec());
sample_buf.copy_interleaved_ref(audio_buf);
for &sample in sample_buf.samples() {
wav_writer.write_sample(sample)?;
total_samples += 1;
}
debug!("🔊 Decoded {} samples", sample_buf.samples().len());
}
Err(e) => warn!("🚫 Error decoding packet: {:?}", e),
}
}
info!("📊 Total packets processed: {}", packet_count);
debug!("📘 Total samples written: {}", total_samples);
Ok(())
}
/// Transcribes the audio file at `converted_audio_path` with the Whisper
/// model selected by `whisper_model` and returns the transcription text.
///
/// The transcription text and its SRT rendering are also printed to stdout.
///
/// # Errors
///
/// Returns an error if the audio file cannot be read, if the model cannot be
/// downloaded, or if transcription fails.
async fn transcribe_audio(
    converted_audio_path: &str,
    whisper_model: &str,
) -> Result<String, Box<dyn std::error::Error>> {
    let wav_data = tokio::fs::read(converted_audio_path).await?;

    // Propagate a failed download instead of panicking. The error is rendered
    // with `Debug`, matching how the transcription error below is reported.
    let model = Model::download(&model_string_to_type(whisper_model))
        .map_err(|e| format!("model download error {:?}", e))?;

    // NOTE(review): download and transcription look like blocking calls made
    // from an async fn — consider `spawn_blocking` if this runs on a shared
    // runtime; confirm against the mutter API first.
    let transcription = model
        .transcribe_audio(wav_data, true, true, None)
        .map_err(|e| format!("transcribe_audio error {:?}", e))?;

    println!("🎙️ Transcription: {}", transcription.as_text());
    println!("📜 SRT: {}", transcription.as_srt());
    Ok(transcription.as_text().to_string())
}
fn model_string_to_type(model_str: &str) -> ModelType {
let model = model_str.to_lowercase();
println!("model_str {}", model_str);
match model.trim() {
"tiny" => ModelType::Tiny,
"tinyen" => ModelType::TinyEn,
"medium" => ModelType::Medium,
"mediumen" => ModelType::MediumEn,
"baseen" => ModelType::BaseEn,
"base" => ModelType::Base,
"large" => ModelType::LargeV3,
_ => {
warn!(
"Your Whisper Modsel Type {} is unknown. Defaulting to Base model.",
model_str
);
ModelType::Base
}
}
}