Compare commits

...

3 Commits

Author SHA1 Message Date
Lucas Kent 19678d887e Fix clippy lints for sasl crate 2024-05-07 07:53:43 +10:00
Jonas Schäfer 384b366f5f Add Message::extract_payload function
This should simplify access to message payloads significantly.
2024-05-06 09:40:08 +00:00
Lucas Kent a291ab2e83 Remove an allocation in client::mechanisms::scram::Scram::initial 2024-05-06 08:25:24 +10:00
9 changed files with 95 additions and 39 deletions

View File

@ -204,6 +204,40 @@ impl Message {
pub fn get_best_subject(&self, preferred_langs: Vec<&str>) -> Option<(Lang, &Subject)> {
Message::get_best::<Subject>(&self.subjects, preferred_langs)
}
/// Try to extract the given payload type from the message's payloads.
///
/// Returns the first matching payload element as parsed struct or its
/// parse error. If no element matches, `Ok(None)` is returned. If an
/// element matches, but fails to parse, it is nontheless removed from
/// the message.
///
/// Elements which do not match the given type are not removed.
pub fn extract_payload<T: TryFrom<Element, Error = Error>>(
&mut self,
) -> Result<Option<T>, Error> {
let mut buf = Vec::with_capacity(self.payloads.len());
let mut iter = self.payloads.drain(..);
let mut result = Ok(None);
for item in &mut iter {
match T::try_from(item) {
Ok(v) => {
result = Ok(Some(v));
break;
}
Err(Error::TypeMismatch(_, _, residual)) => {
buf.push(residual);
}
Err(other) => {
result = Err(other);
break;
}
}
}
buf.extend(iter);
std::mem::swap(&mut buf, &mut self.payloads);
result
}
}
impl TryFrom<Element> for Message {
@ -460,4 +494,27 @@ mod tests {
let elem2 = message.into();
assert_eq!(elem1, elem2);
}
#[test]
fn test_extract_payload() {
use super::super::attention::Attention;
use super::super::pubsub::event::PubSubEvent;
#[cfg(not(feature = "component"))]
let elem: Element = "<message xmlns='jabber:client' to='coucou@example.org' type='chat'><attention xmlns='urn:xmpp:attention:0'/></message>".parse().unwrap();
#[cfg(feature = "component")]
let elem: Element = "<message xmlns='jabber:component:accept' to='coucou@example.org' type='chat'><attention xmlns='urn:xmpp:attention:0'/></message>".parse().unwrap();
let mut message = Message::try_from(elem).unwrap();
assert_eq!(message.payloads.len(), 1);
match message.extract_payload::<PubSubEvent>() {
Ok(None) => (),
other => panic!("unexpected result: {:?}", other),
};
assert_eq!(message.payloads.len(), 1);
match message.extract_payload::<Attention>() {
Ok(Some(_)) => (),
other => panic!("unexpected result: {:?}", other),
};
assert_eq!(message.payloads.len(), 0);
}
}

View File

@ -11,6 +11,7 @@ impl Anonymous {
///
/// It is recommended that instead you use a `Credentials` struct and turn it into the
/// requested mechanism using `from_credentials`.
#[allow(clippy::new_without_default)]
pub fn new() -> Anonymous {
Anonymous
}

View File

@ -51,7 +51,7 @@ impl<S: ScramProvider> Scram<S> {
password: password.into(),
client_nonce: generate_nonce()?,
state: ScramState::Init,
channel_binding: channel_binding,
channel_binding,
_marker: PhantomData,
})
}
@ -109,10 +109,10 @@ impl<S: ScramProvider> Mechanism for Scram<S> {
bare.extend(self.client_nonce.bytes());
let mut data = Vec::new();
data.extend(&gs2_header);
data.extend(bare.clone());
data.extend(&bare);
self.state = ScramState::SentInitialMessage {
initial_message: bare,
gs2_header: gs2_header,
gs2_header,
};
data
}
@ -130,9 +130,9 @@ impl<S: ScramProvider> Mechanism for Scram<S> {
let server_nonce = frame.get("r");
let salt = frame.get("s").and_then(|v| Base64.decode(v).ok());
let iterations = frame.get("i").and_then(|v| v.parse().ok());
let server_nonce = server_nonce.ok_or_else(|| MechanismError::NoServerNonce)?;
let salt = salt.ok_or_else(|| MechanismError::NoServerSalt)?;
let iterations = iterations.ok_or_else(|| MechanismError::NoServerIterations)?;
let server_nonce = server_nonce.ok_or(MechanismError::NoServerNonce)?;
let salt = salt.ok_or(MechanismError::NoServerSalt)?;
let iterations = iterations.ok_or(MechanismError::NoServerIterations)?;
// TODO: SASLprep
let mut client_final_message_bare = Vec::new();
client_final_message_bare.extend(b"c=");
@ -158,9 +158,9 @@ impl<S: ScramProvider> Mechanism for Scram<S> {
let mut client_final_message = Vec::new();
client_final_message.extend(&client_final_message_bare);
client_final_message.extend(b",p=");
client_final_message.extend(Base64.encode(&client_proof).bytes());
client_final_message.extend(Base64.encode(client_proof).bytes());
next_state = ScramState::GotServerData {
server_signature: server_signature,
server_signature,
};
ret = client_final_message;
}
@ -178,7 +178,7 @@ impl<S: ScramProvider> Mechanism for Scram<S> {
ScramState::GotServerData {
ref server_signature,
} => {
if let Some(sig) = frame.get("v").and_then(|v| Base64.decode(&v).ok()) {
if let Some(sig) = frame.get("v").and_then(|v| Base64.decode(v).ok()) {
if sig == *server_signature {
Ok(())
} else {

View File

@ -86,9 +86,9 @@ impl Secret {
) -> Secret {
Secret::Password(Password::Pbkdf2 {
method: method.into(),
salt: salt,
iterations: iterations,
data: data,
salt,
iterations,
data,
})
}
}
@ -135,7 +135,7 @@ fn xor_works() {
pub fn xor(a: &[u8], b: &[u8]) -> Vec<u8> {
assert_eq!(a.len(), b.len());
let mut ret = Vec::with_capacity(a.len());
for (a, b) in a.into_iter().zip(b) {
for (a, b) in a.iter().zip(b) {
ret.push(a ^ b);
}
ret
@ -149,11 +149,8 @@ pub fn parse_frame(frame: &[u8]) -> Result<HashMap<String, String>, FromUtf8Erro
let mut tmp = s.splitn(2, '=');
let key = tmp.next();
let val = tmp.next();
match (key, val) {
(Some(k), Some(v)) => {
ret.insert(k.to_owned(), v.to_owned());
}
_ => (),
if let (Some(k), Some(v)) = (key, val) {
ret.insert(k.to_owned(), v.to_owned());
}
}
Ok(ret)

View File

@ -14,7 +14,7 @@ use base64::{engine::general_purpose::STANDARD as Base64, Engine};
pub fn generate_nonce() -> Result<String, RngError> {
let mut data = [0u8; 32];
getrandom(&mut data)?;
Ok(Base64.encode(&data))
Ok(Base64.encode(data))
}
#[derive(Debug, PartialEq)]
@ -111,7 +111,7 @@ impl ScramProvider for Sha1 {
method.to_string(),
Self::name().to_string(),
))
} else if my_salt == &salt {
} else if my_salt == salt {
Err(DeriveError::IncorrectSalt)
} else if my_iterations == iterations {
Err(DeriveError::IncompatibleIterationCount(
@ -171,7 +171,7 @@ impl ScramProvider for Sha256 {
method.to_string(),
Self::name().to_string(),
))
} else if my_salt == &salt {
} else if my_salt == salt {
Err(DeriveError::IncorrectSalt)
} else if my_iterations == iterations {
Err(DeriveError::IncompatibleIterationCount(

View File

@ -30,8 +30,8 @@ impl Pbkdf2Sha1 {
let digest = Sha1::derive(&Password::Plain(password.to_owned()), salt, iterations)?;
Ok(Pbkdf2Sha1 {
salt: salt.to_vec(),
iterations: iterations,
digest: digest,
iterations,
digest,
})
}
}
@ -70,8 +70,8 @@ impl Pbkdf2Sha256 {
let digest = Sha256::derive(&Password::Plain(password.to_owned()), salt, iterations)?;
Ok(Pbkdf2Sha256 {
salt: salt.to_vec(),
iterations: iterations,
digest: digest,
iterations,
digest,
})
}
}

View File

@ -6,6 +6,7 @@ use getrandom::getrandom;
pub struct Anonymous;
impl Anonymous {
#[allow(clippy::new_without_default)]
pub fn new() -> Anonymous {
Anonymous
}

View File

@ -9,7 +9,7 @@ pub struct Plain<V: Validator<secret::Plain>> {
impl<V: Validator<secret::Plain>> Plain<V> {
pub fn new(validator: V) -> Plain<V> {
Plain {
validator: validator,
validator,
}
}
}
@ -24,12 +24,12 @@ impl<V: Validator<secret::Plain>> Mechanism for Plain<V> {
sp.next();
let username = sp
.next()
.ok_or_else(|| MechanismError::NoUsernameSpecified)?;
.ok_or(MechanismError::NoUsernameSpecified)?;
let username = String::from_utf8(username.to_vec())
.map_err(|_| MechanismError::ErrorDecodingUsername)?;
let password = sp
.next()
.ok_or_else(|| MechanismError::NoPasswordSpecified)?;
.ok_or(MechanismError::NoPasswordSpecified)?;
let password = String::from_utf8(password.to_vec())
.map_err(|_| MechanismError::ErrorDecodingPassword)?;
let ident = Identity::Username(username);

View File

@ -44,8 +44,8 @@ where
Scram {
name: format!("SCRAM-{}", S::name()),
state: ScramState::Init,
channel_binding: channel_binding,
provider: provider,
channel_binding,
provider,
_marker: PhantomData,
}
}
@ -108,9 +108,9 @@ where
}
let frame =
parse_frame(&rest).map_err(|_| MechanismError::CannotDecodeInitialMessage)?;
let username = frame.get("n").ok_or_else(|| MechanismError::NoUsername)?;
let username = frame.get("n").ok_or(MechanismError::NoUsername)?;
let identity = Identity::Username(username.to_owned());
let client_nonce = frame.get("r").ok_or_else(|| MechanismError::NoNonce)?;
let client_nonce = frame.get("r").ok_or(MechanismError::NoNonce)?;
let mut server_nonce = String::new();
server_nonce += client_nonce;
server_nonce +=
@ -125,12 +125,12 @@ where
buf.extend(pbkdf2.iterations().to_string().bytes());
ret = Response::Proceed(buf.clone());
next_state = ScramState::SentChallenge {
server_nonce: server_nonce,
identity: identity,
server_nonce,
identity,
salted_password: pbkdf2.digest().to_vec(),
initial_client_message: rest,
initial_server_message: buf,
gs2_header: gs2_header,
gs2_header,
};
}
ScramState::SentChallenge {
@ -151,8 +151,8 @@ where
client_final_message_bare.extend(Base64.encode(&cb_data).bytes());
client_final_message_bare.extend(b",r=");
client_final_message_bare.extend(server_nonce.bytes());
let client_key = S::hmac(b"Client Key", &salted_password)?;
let server_key = S::hmac(b"Server Key", &salted_password)?;
let client_key = S::hmac(b"Client Key", salted_password)?;
let server_key = S::hmac(b"Server Key", salted_password)?;
let mut auth_message = Vec::new();
auth_message.extend(initial_client_message);
auth_message.extend(b",");
@ -162,7 +162,7 @@ where
let stored_key = S::hash(&client_key);
let client_signature = S::hmac(&auth_message, &stored_key)?;
let client_proof = xor(&client_key, &client_signature);
let sent_proof = frame.get("p").ok_or_else(|| MechanismError::NoProof)?;
let sent_proof = frame.get("p").ok_or(MechanismError::NoProof)?;
let sent_proof = Base64
.decode(sent_proof)
.map_err(|_| MechanismError::CannotDecodeProof)?;
@ -172,7 +172,7 @@ where
let server_signature = S::hmac(&auth_message, &server_key)?;
let mut buf = Vec::new();
buf.extend(b"v=");
buf.extend(Base64.encode(&server_signature).bytes());
buf.extend(Base64.encode(server_signature).bytes());
ret = Response::Success(identity.clone(), buf);
next_state = ScramState::Done;
}