mas_config/sections/
clients.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::ops::Deref;
8
9use mas_iana::oauth::OAuthClientAuthenticationMethod;
10use mas_jose::jwk::PublicJsonWebKeySet;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize, de::Error};
13use serde_with::serde_as;
14use ulid::Ulid;
15use url::Url;
16
17use super::{ClientSecret, ClientSecretRaw, ConfigurationSection};
18
/// A JSON Web Key Set, given either inline or as a URL
///
/// With `rename_all = "snake_case"`, the variants are (de)serialized under the
/// `jwks` and `jwks_uri` names.
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum JwksOrJwksUri {
    /// The key set itself
    Jwks(PublicJsonWebKeySet),
    /// A URL of the key set
    JwksUri(Url),
}
25
26impl From<PublicJsonWebKeySet> for JwksOrJwksUri {
27    fn from(jwks: PublicJsonWebKeySet) -> Self {
28        Self::Jwks(jwks)
29    }
30}
31
32/// Authentication method used by clients
33#[derive(JsonSchema, Serialize, Deserialize, Copy, Clone, Debug)]
34#[serde(rename_all = "snake_case")]
35pub enum ClientAuthMethodConfig {
36    /// `none`: No authentication
37    None,
38
39    /// `client_secret_basic`: `client_id` and `client_secret` used as basic
40    /// authorization credentials
41    ClientSecretBasic,
42
43    /// `client_secret_post`: `client_id` and `client_secret` sent in the
44    /// request body
45    ClientSecretPost,
46
47    /// `client_secret_basic`: a `client_assertion` sent in the request body and
48    /// signed using the `client_secret`
49    ClientSecretJwt,
50
51    /// `client_secret_basic`: a `client_assertion` sent in the request body and
52    /// signed by an asymmetric key
53    PrivateKeyJwt,
54}
55
56impl std::fmt::Display for ClientAuthMethodConfig {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        match self {
59            ClientAuthMethodConfig::None => write!(f, "none"),
60            ClientAuthMethodConfig::ClientSecretBasic => write!(f, "client_secret_basic"),
61            ClientAuthMethodConfig::ClientSecretPost => write!(f, "client_secret_post"),
62            ClientAuthMethodConfig::ClientSecretJwt => write!(f, "client_secret_jwt"),
63            ClientAuthMethodConfig::PrivateKeyJwt => write!(f, "private_key_jwt"),
64        }
65    }
66}
67
// Which of `client_secret`, `jwks` and `jwks_uri` may be set depends on
// `client_auth_method`; `ClientConfig::validate` enforces those rules.
/// An OAuth 2.0 client configuration
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ClientConfig {
    /// The client ID
    // The schema describes the ULID as a plain string restricted to 26
    // characters of the alphabet listed in the pattern below.
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub client_id: Ulid,

    /// Authentication method used for this client
    // Private: read through `ClientConfig::client_auth_method()`, which
    // converts it to `OAuthClientAuthenticationMethod`.
    client_auth_method: ClientAuthMethodConfig,

    /// Name of the `OAuth2` client
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_name: Option<String>,

    /// The client secret, used by the `client_secret_basic`,
    /// `client_secret_post` and `client_secret_jwt` authentication methods
    // Flattened: the keys of `ClientSecretRaw` sit directly on the client
    // entry (the tests in this file use `client_secret` and
    // `client_secret_file`). Use `ClientConfig::client_secret()` to obtain the
    // actual secret value.
    #[schemars(with = "ClientSecretRaw")]
    #[serde_as(as = "serde_with::TryFromInto<ClientSecretRaw>")]
    #[serde(flatten)]
    pub client_secret: Option<ClientSecret>,

    /// The JSON Web Key Set (JWKS) used by the `private_key_jwt` authentication
    /// method. Mutually exclusive with `jwks_uri`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks: Option<PublicJsonWebKeySet>,

    /// The URL of the JSON Web Key Set (JWKS) used by the `private_key_jwt`
    /// authentication method. Mutually exclusive with `jwks`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    /// List of allowed redirect URIs
    // Defaults to an empty list when absent, and is omitted from the
    // serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub redirect_uris: Vec<Url>,
}
108
109impl ClientConfig {
110    fn validate(&self) -> Result<(), Box<figment::error::Error>> {
111        let auth_method = self.client_auth_method;
112        match self.client_auth_method {
113            ClientAuthMethodConfig::PrivateKeyJwt => {
114                if self.jwks.is_none() && self.jwks_uri.is_none() {
115                    let error = figment::error::Error::custom(
116                        "jwks or jwks_uri is required for private_key_jwt",
117                    );
118                    return Err(Box::new(error.with_path("client_auth_method")));
119                }
120
121                if self.jwks.is_some() && self.jwks_uri.is_some() {
122                    let error =
123                        figment::error::Error::custom("jwks and jwks_uri are mutually exclusive");
124                    return Err(Box::new(error.with_path("jwks")));
125                }
126
127                if self.client_secret.is_some() {
128                    let error = figment::error::Error::custom(
129                        "client_secret is not allowed with private_key_jwt",
130                    );
131                    return Err(Box::new(error.with_path("client_secret")));
132                }
133            }
134
135            ClientAuthMethodConfig::ClientSecretPost
136            | ClientAuthMethodConfig::ClientSecretBasic
137            | ClientAuthMethodConfig::ClientSecretJwt => {
138                if self.client_secret.is_none() {
139                    let error = figment::error::Error::custom(format!(
140                        "client_secret is required for {auth_method}"
141                    ));
142                    return Err(Box::new(error.with_path("client_auth_method")));
143                }
144
145                if self.jwks.is_some() {
146                    let error = figment::error::Error::custom(format!(
147                        "jwks is not allowed with {auth_method}"
148                    ));
149                    return Err(Box::new(error.with_path("jwks")));
150                }
151
152                if self.jwks_uri.is_some() {
153                    let error = figment::error::Error::custom(format!(
154                        "jwks_uri is not allowed with {auth_method}"
155                    ));
156                    return Err(Box::new(error.with_path("jwks_uri")));
157                }
158            }
159
160            ClientAuthMethodConfig::None => {
161                if self.client_secret.is_some() {
162                    let error = figment::error::Error::custom(
163                        "client_secret is not allowed with none authentication method",
164                    );
165                    return Err(Box::new(error.with_path("client_secret")));
166                }
167
168                if self.jwks.is_some() {
169                    let error = figment::error::Error::custom(
170                        "jwks is not allowed with none authentication method",
171                    );
172                    return Err(Box::new(error));
173                }
174
175                if self.jwks_uri.is_some() {
176                    let error = figment::error::Error::custom(
177                        "jwks_uri is not allowed with none authentication method",
178                    );
179                    return Err(Box::new(error));
180                }
181            }
182        }
183
184        Ok(())
185    }
186
187    /// Authentication method used for this client
188    #[must_use]
189    pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
190        match self.client_auth_method {
191            ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
192            ClientAuthMethodConfig::ClientSecretBasic => {
193                OAuthClientAuthenticationMethod::ClientSecretBasic
194            }
195            ClientAuthMethodConfig::ClientSecretPost => {
196                OAuthClientAuthenticationMethod::ClientSecretPost
197            }
198            ClientAuthMethodConfig::ClientSecretJwt => {
199                OAuthClientAuthenticationMethod::ClientSecretJwt
200            }
201            ClientAuthMethodConfig::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
202        }
203    }
204
205    /// Returns the client secret.
206    ///
207    /// If `client_secret_file` was given, the secret is read from that file.
208    ///
209    /// # Errors
210    ///
211    /// Returns an error when the client secret could not be read from file.
212    pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
213        Ok(match &self.client_secret {
214            Some(client_secret) => Some(client_secret.value().await?),
215            None => None,
216        })
217    }
218}
219
// `#[serde(transparent)]`: (de)serialized as the bare list of clients, without
// any wrapping key. The schema override below describes the single field as a
// `Vec<ClientConfig>`.
/// List of OAuth 2.0/OIDC clients config
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
#[serde(transparent)]
pub struct ClientsConfig(#[schemars(with = "Vec::<ClientConfig>")] Vec<ClientConfig>);
224
225impl ClientsConfig {
226    /// Returns true if all fields are at their default values
227    pub(crate) fn is_default(&self) -> bool {
228        self.0.is_empty()
229    }
230}
231
232impl Deref for ClientsConfig {
233    type Target = Vec<ClientConfig>;
234
235    fn deref(&self) -> &Self::Target {
236        &self.0
237    }
238}
239
240impl IntoIterator for ClientsConfig {
241    type Item = ClientConfig;
242    type IntoIter = std::vec::IntoIter<ClientConfig>;
243
244    fn into_iter(self) -> Self::IntoIter {
245        self.0.into_iter()
246    }
247}
248
249impl ConfigurationSection for ClientsConfig {
250    const PATH: Option<&'static str> = Some("clients");
251
252    fn validate(
253        &self,
254        figment: &figment::Figment,
255    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
256        for (index, client) in self.0.iter().enumerate() {
257            client.validate().map_err(|mut err| {
258                // Save the error location information in the error
259                err.metadata = figment.find_metadata(Self::PATH.unwrap()).cloned();
260                err.profile = Some(figment::Profile::Default);
261                err.path.insert(0, Self::PATH.unwrap().to_owned());
262                err.path.insert(1, format!("{index}"));
263                err
264            })?;
265        }
266
267        Ok(())
268    }
269}
270
// Tests for loading the `clients` section from a YAML file.
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use tokio::{runtime::Handle, task};

    use super::*;

    /// Loads five clients, one per authentication method, and checks the
    /// parsed IDs, redirect URIs and client secrets (inline and file-based).
    #[tokio::test]
    async fn load_config() {
        // NOTE(review): presumably run on a blocking thread so that
        // `Handle::block_on` below is not called from within the async test
        // task itself — confirm.
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                jail.create_file(
                    "config.yaml",
                    r#"
                      clients:
                        - client_id: 01GFWR28C4KNE04WG3HKXB7C9R
                          client_auth_method: none
                          redirect_uris:
                            - https://exemple.fr/callback

                        - client_id: 01GFWR32NCQ12B8Z0J8CPXRRB6
                          client_auth_method: client_secret_basic
                          client_secret_file: secret

                        - client_id: 01GFWR3WHR93Y5HK389H28VHZ9
                          client_auth_method: client_secret_post
                          client_secret: c1!3n753c237

                        - client_id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
                          client_auth_method: client_secret_jwt
                          client_secret_file: secret

                        - client_id: 01GFWR4BNFDCC4QDG6AMSP1VRR
                          client_auth_method: private_key_jwt
                          jwks:
                            keys:
                            - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
                              kty: "RSA"
                              alg: "RS256"
                              use: "sig"
                              e: "AQAB"
                              n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"

                            - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
                              kty: "RSA"
                              alg: "RS256"
                              use: "sig"
                              e: "AQAB"
                              n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
                    "#,
                )?;
                // The file referenced by the `client_secret_file: secret` entries
                jail.create_file("secret", r"c1!3n753c237")?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<ClientsConfig>("clients")?;

                assert_eq!(config.0.len(), 5);

                assert_eq!(
                    config.0[0].client_id,
                    Ulid::from_str("01GFWR28C4KNE04WG3HKXB7C9R").unwrap()
                );
                assert_eq!(
                    config.0[0].redirect_uris,
                    vec!["https://exemple.fr/callback".parse().unwrap()]
                );

                assert_eq!(
                    config.0[1].client_id,
                    Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
                );
                // `redirect_uris` was not given for this client
                assert_eq!(config.0[1].redirect_uris, Vec::new());

                // Secrets are kept as configured: a file path or an inline value
                assert!(config.0[0].client_secret.is_none());
                assert!(matches!(config.0[1].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(matches!(config.0[2].client_secret, Some(ClientSecret::Value(ref v)) if v == "c1!3n753c237"));
                assert!(matches!(config.0[3].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(config.0[4].client_secret.is_none());

                // Resolving the secret yields the same value whether it was
                // given inline or through the `secret` file
                Handle::current().block_on(async move {
                    assert_eq!(config.0[1].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.0[2].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.0[3].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                });

                Ok(())
            });
        }).await.unwrap();
    }
}