1use penumbra_sdk_proto::custody::v1::{self as pb, AuthorizeResponse};
2use rand_core::OsRng;
3use serde::{Deserialize, Serialize};
4use serde_with::{formats::Uppercase, hex::Hex};
5use tokio::sync::OnceCell;
6use tonic::{async_trait, Request, Response, Status};
7
8use crate::{soft_kms, terminal::Terminal, threshold};
9
10mod encryption {
11 use anyhow::anyhow;
12 use chacha20poly1305::{
13 aead::{AeadInPlace, NewAead},
14 ChaCha20Poly1305,
15 };
16 use rand_core::CryptoRngCore;
17
18 #[derive(Clone, Copy)]
20 pub struct Password<'a>(&'a str);
21
22 impl<'a> Password<'a> {
23 pub fn new(password: &'a str) -> anyhow::Result<Self> {
25 anyhow::ensure!(password.len() < argon2::MAX_PWD_LEN, "password too long");
26 Ok(Self(password))
27 }
28 }
29
30 impl<'a> TryFrom<&'a str> for Password<'a> {
31 type Error = anyhow::Error;
32
33 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
34 Self::new(value)
35 }
36 }
37
38 const SALT_SIZE: usize = 32;
40 const TAG_SIZE: usize = 16;
41 const KEY_SIZE: usize = 32;
42
43 fn derive_key(salt: &[u8; SALT_SIZE], password: Password<'_>) -> [u8; KEY_SIZE] {
44 let mut key = [0u8; KEY_SIZE];
45 argon2::Argon2::hash_password_into(
48 &argon2::Argon2::new(
50 argon2::Algorithm::Argon2id,
51 argon2::Version::V0x13,
52 argon2::Params::new(1 << 21, 1, 4, Some(KEY_SIZE))
53 .expect("the parameters should be valid"),
54 ),
55 password.0.as_bytes(),
56 salt,
57 &mut key,
58 )
59 .expect("password hashing should not fail with a small enough password");
60 key
61 }
62
63 pub fn encrypt(rng: &mut impl CryptoRngCore, password: Password<'_>, data: &[u8]) -> Vec<u8> {
64 let salt = {
70 let mut out = [0u8; SALT_SIZE];
71 rng.fill_bytes(&mut out);
72 out
73 };
74 let key = derive_key(&salt, password);
75
76 let mut ciphertext = Vec::with_capacity(TAG_SIZE + salt.len() + data.len());
77 ciphertext.extend_from_slice(&[0u8; TAG_SIZE]);
78 ciphertext.extend_from_slice(&salt);
79 ciphertext.extend_from_slice(&data);
80 let tag = ChaCha20Poly1305::new(&key.into())
81 .encrypt_in_place_detached(
82 &Default::default(),
83 &salt,
84 &mut ciphertext[TAG_SIZE + SALT_SIZE..],
85 )
86 .expect("XChaCha20Poly1305 encryption should not fail");
87 ciphertext[0..TAG_SIZE].copy_from_slice(&tag);
88 ciphertext
89 }
90
91 pub fn decrypt(password: Password<'_>, data: &[u8]) -> anyhow::Result<Vec<u8>> {
92 anyhow::ensure!(
93 data.len() >= TAG_SIZE + SALT_SIZE,
94 "provided ciphertext is too short"
95 );
96 let (header, message) = data.split_at(TAG_SIZE + SALT_SIZE);
97 let mut message = message.to_owned();
98 let tag = &header[..TAG_SIZE];
99 let salt = &header[TAG_SIZE..TAG_SIZE + SALT_SIZE];
100 let key = derive_key(
101 &salt.try_into().expect("salt is the right length"),
102 password,
103 );
104 ChaCha20Poly1305::new(&key.into())
105 .decrypt_in_place_detached(&Default::default(), &salt, &mut message, tag.into())
106 .map_err(|_| anyhow!("failed to decrypt ciphertext"))?;
107 Ok(message)
108 }
109
110 #[cfg(test)]
111 mod test {
112 use rand_core::OsRng;
113
114 use super::*;
115
116 #[test]
117 fn test_encryption_decryption_roundtrip() -> anyhow::Result<()> {
118 let password = "password".try_into()?;
119 let message = b"hello world";
120 let encrypted = encrypt(&mut OsRng, password, message);
121 let decrypted = decrypt(password, &encrypted)?;
122 assert_eq!(decrypted.as_slice(), message);
123 Ok(())
124 }
125
126 #[test]
127 fn test_encryption_fails_with_different_password() -> anyhow::Result<()> {
128 let password = "password".try_into()?;
129 let message = b"hello world";
130 let encrypted = encrypt(&mut OsRng, password, message);
131 let decrypted = decrypt("not password".try_into()?, &encrypted);
132 assert!(decrypted.is_err());
133 Ok(())
134 }
135 }
136}
137
138use encryption::{decrypt, encrypt};
139
/// The plaintext custody configuration that is stored encrypted inside a [`Config`].
///
/// Encoded as JSON by [`InnerConfig::to_bytes`] and decoded by [`InnerConfig::from_bytes`].
#[derive(Serialize, Deserialize)]
pub enum InnerConfig {
    /// Settings for the `soft_kms` custody backend.
    SoftKms(soft_kms::Config),
    /// Settings for the `threshold` custody backend.
    Threshold(threshold::Config),
}
146
147impl InnerConfig {
148 pub fn from_bytes(data: &[u8]) -> anyhow::Result<Self> {
149 Ok(serde_json::from_slice(data)?)
150 }
151
152 pub fn to_bytes(self) -> anyhow::Result<Vec<u8>> {
153 Ok(serde_json::to_vec(&self)?)
154 }
155}
156
157#[serde_as]
161#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
162pub struct Config {
163 #[serde_as(as = "Hex<Uppercase>")]
164 data: Vec<u8>,
165}
166
167impl Config {
168 pub fn create(password: &str, inner: InnerConfig) -> anyhow::Result<Self> {
170 let password = password.try_into()?;
171 Ok(Self {
172 data: encrypt(&mut OsRng, password, &inner.to_bytes()?),
173 })
174 }
175
176 fn decrypt(self, password: &str) -> anyhow::Result<InnerConfig> {
177 let decrypted_data = decrypt(password.try_into()?, &self.data)?;
178 Ok(InnerConfig::from_bytes(&decrypted_data)?)
179 }
180
181 pub fn convert_to_threshold(self, password: &str) -> anyhow::Result<Option<threshold::Config>> {
183 match self.decrypt(password)? {
184 InnerConfig::SoftKms(_) => Ok(None),
185 InnerConfig::Threshold(c) => Ok(Some(c)),
186 }
187 }
188}
189
/// A custody service backed by an encrypted [`Config`].
///
/// The configuration is decrypted lazily with a password obtained from `terminal`. The resulting
/// backend, or the error from that attempt, is cached in `inner`.
pub struct Encrypted<T> {
    /// The encrypted configuration to decrypt on first use.
    config: Config,
    /// Source of the password (`get_password`); it is also handed to the threshold backend.
    terminal: T,
    /// Result of the one-time initialization. Because the `Result` itself is cached, a failure
    /// is remembered as well.
    inner: OnceCell<anyhow::Result<Box<dyn pb::custody_service_server::CustodyService>>>,
}
198
199impl<T: Terminal + Clone + Send + Sync + 'static> Encrypted<T> {
200 pub fn new(config: Config, terminal: T) -> Self {
202 Self {
203 config,
204 terminal,
205 inner: Default::default(),
206 }
207 }
208
209 async fn get_inner(&self) -> Result<&dyn pb::custody_service_server::CustodyService, Status> {
210 Ok(self
211 .inner
212 .get_or_init(|| async {
213 let password = self.terminal.get_password().await?;
214
215 let inner = self.config.clone().decrypt(&password)?;
216 let out: Box<dyn pb::custody_service_server::CustodyService> = match inner {
217 InnerConfig::SoftKms(c) => Box::new(soft_kms::SoftKms::new(c)),
218 InnerConfig::Threshold(c) => {
219 Box::new(threshold::Threshold::new(c, self.terminal.clone()))
220 }
221 };
222 Ok(out)
223 })
224 .await
225 .as_ref()
226 .map_err(|e| Status::unauthenticated(format!("failed to initialize custody {e}")))?
227 .as_ref())
228 }
229}
230
231#[async_trait]
232impl<T: Terminal + Clone + Send + Sync + 'static> pb::custody_service_server::CustodyService
233 for Encrypted<T>
234{
235 async fn authorize(
236 &self,
237 request: Request<pb::AuthorizeRequest>,
238 ) -> Result<Response<AuthorizeResponse>, Status> {
239 self.get_inner().await?.authorize(request).await
240 }
241
242 async fn authorize_validator_definition(
243 &self,
244 request: Request<pb::AuthorizeValidatorDefinitionRequest>,
245 ) -> Result<Response<pb::AuthorizeValidatorDefinitionResponse>, Status> {
246 self.get_inner()
247 .await?
248 .authorize_validator_definition(request)
249 .await
250 }
251
252 async fn authorize_validator_vote(
253 &self,
254 request: Request<pb::AuthorizeValidatorVoteRequest>,
255 ) -> Result<Response<pb::AuthorizeValidatorVoteResponse>, Status> {
256 self.get_inner()
257 .await?
258 .authorize_validator_vote(request)
259 .await
260 }
261
262 async fn export_full_viewing_key(
263 &self,
264 request: Request<pb::ExportFullViewingKeyRequest>,
265 ) -> Result<Response<pb::ExportFullViewingKeyResponse>, Status> {
266 self.get_inner()
267 .await?
268 .export_full_viewing_key(request)
269 .await
270 }
271
272 async fn confirm_address(
273 &self,
274 request: Request<pb::ConfirmAddressRequest>,
275 ) -> Result<Response<pb::ConfirmAddressResponse>, Status> {
276 self.get_inner().await?.confirm_address(request).await
277 }
278}