1use std::{cell::RefCell, fmt, str};
4
5use crate::{
6 body::{EitherBody, MessageBody},
7 error::{Error, HttpError},
8 header::{self, TryIntoHeaderPair, TryIntoHeaderValue},
9 responses::{BoxedResponseHead, ResponseHead},
10 ConnectionType, Extensions, Response, StatusCode,
11};
12
13pub struct ResponseBuilder {
39 head: Option<BoxedResponseHead>,
40 err: Option<HttpError>,
41}
42
43impl ResponseBuilder {
44 #[inline]
53 pub fn new(status: StatusCode) -> Self {
54 ResponseBuilder {
55 head: Some(BoxedResponseHead::new(status)),
56 err: None,
57 }
58 }
59
60 #[inline]
69 pub fn status(&mut self, status: StatusCode) -> &mut Self {
70 if let Some(parts) = self.inner() {
71 parts.status = status;
72 }
73 self
74 }
75
76 pub fn insert_header(&mut self, header: impl TryIntoHeaderPair) -> &mut Self {
91 if let Some(parts) = self.inner() {
92 match header.try_into_pair() {
93 Ok((key, value)) => {
94 parts.headers.insert(key, value);
95 }
96 Err(err) => self.err = Some(err.into()),
97 };
98 }
99
100 self
101 }
102
103 pub fn append_header(&mut self, header: impl TryIntoHeaderPair) -> &mut Self {
119 if let Some(parts) = self.inner() {
120 match header.try_into_pair() {
121 Ok((key, value)) => parts.headers.append(key, value),
122 Err(err) => self.err = Some(err.into()),
123 };
124 }
125
126 self
127 }
128
129 #[inline]
131 pub fn reason(&mut self, reason: &'static str) -> &mut Self {
132 if let Some(parts) = self.inner() {
133 parts.reason = Some(reason);
134 }
135 self
136 }
137
138 #[inline]
140 pub fn keep_alive(&mut self) -> &mut Self {
141 if let Some(parts) = self.inner() {
142 parts.set_connection_type(ConnectionType::KeepAlive);
143 }
144 self
145 }
146
147 #[inline]
149 pub fn upgrade<V>(&mut self, value: V) -> &mut Self
150 where
151 V: TryIntoHeaderValue,
152 {
153 if let Some(parts) = self.inner() {
154 parts.set_connection_type(ConnectionType::Upgrade);
155 }
156
157 if let Ok(value) = value.try_into_value() {
158 self.insert_header((header::UPGRADE, value));
159 }
160
161 self
162 }
163
164 #[inline]
166 pub fn force_close(&mut self) -> &mut Self {
167 if let Some(parts) = self.inner() {
168 parts.set_connection_type(ConnectionType::Close);
169 }
170 self
171 }
172
173 #[inline]
175 pub fn no_chunking(&mut self, len: u64) -> &mut Self {
176 let mut buf = itoa::Buffer::new();
177 self.insert_header((header::CONTENT_LENGTH, buf.format(len)));
178
179 if let Some(parts) = self.inner() {
180 parts.no_chunking(true);
181 }
182 self
183 }
184
185 #[inline]
187 pub fn content_type<V>(&mut self, value: V) -> &mut Self
188 where
189 V: TryIntoHeaderValue,
190 {
191 if let Some(parts) = self.inner() {
192 match value.try_into_value() {
193 Ok(value) => {
194 parts.headers.insert(header::CONTENT_TYPE, value);
195 }
196 Err(err) => self.err = Some(err.into()),
197 };
198 }
199 self
200 }
201
202 pub fn body<B>(&mut self, body: B) -> Response<EitherBody<B>>
206 where
207 B: MessageBody + 'static,
208 {
209 match self.message_body(body) {
210 Ok(res) => res.map_body(|_, body| EitherBody::left(body)),
211 Err(err) => Response::from(err).map_body(|_, body| EitherBody::right(body)),
212 }
213 }
214
215 pub fn message_body<B>(&mut self, body: B) -> Result<Response<B>, Error> {
219 if let Some(err) = self.err.take() {
220 return Err(Error::new_http().with_cause(err));
221 }
222
223 let head = self.head.take().expect("cannot reuse response builder");
224
225 Ok(Response {
226 head,
227 body,
228 extensions: RefCell::new(Extensions::new()),
229 })
230 }
231
232 #[inline]
236 pub fn finish(&mut self) -> Response<EitherBody<()>> {
237 self.body(())
238 }
239
240 pub fn take(&mut self) -> ResponseBuilder {
242 ResponseBuilder {
243 head: self.head.take(),
244 err: self.err.take(),
245 }
246 }
247
248 fn inner(&mut self) -> Option<&mut ResponseHead> {
250 if self.err.is_some() {
251 return None;
252 }
253
254 self.head.as_deref_mut()
255 }
256}
257
258impl Default for ResponseBuilder {
259 fn default() -> Self {
260 Self::new(StatusCode::OK)
261 }
262}
263
264impl<B> From<Response<B>> for ResponseBuilder {
266 fn from(res: Response<B>) -> ResponseBuilder {
267 ResponseBuilder {
268 head: Some(res.head),
269 err: None,
270 }
271 }
272}
273
274impl<'a> From<&'a ResponseHead> for ResponseBuilder {
276 fn from(head: &'a ResponseHead) -> ResponseBuilder {
277 let mut msg = BoxedResponseHead::new(head.status);
278 msg.version = head.version;
279 msg.reason = head.reason;
280
281 for (k, v) in head.headers.iter() {
282 msg.headers.append(k.clone(), v.clone());
283 }
284
285 msg.no_chunking(!head.chunked());
286
287 ResponseBuilder {
288 head: Some(msg),
289 err: None,
290 }
291 }
292}
293
294impl fmt::Debug for ResponseBuilder {
295 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296 let head = self.head.as_ref().unwrap();
297
298 let res = writeln!(
299 f,
300 "\nResponseBuilder {:?} {}{}",
301 head.version,
302 head.status,
303 head.reason.unwrap_or(""),
304 );
305 let _ = writeln!(f, " headers:");
306 for (key, val) in head.headers.iter() {
307 let _ = writeln!(f, " {:?}: {:?}", key, val);
308 }
309 res
310 }
311}
312
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::header::{HeaderName, HeaderValue, CONTENT_TYPE};

    /// Collects every `Content-Type` value present on `res`.
    fn content_type_values<B>(res: &Response<B>) -> Vec<HeaderValue> {
        res.headers().get_all("Content-Type").cloned().collect()
    }

    #[test]
    fn test_basic_builder() {
        let built = Response::build(StatusCode::OK)
            .insert_header(("X-TEST", "value"))
            .finish();

        assert_eq!(built.status(), StatusCode::OK);
    }

    #[test]
    fn test_upgrade() {
        let built = Response::build(StatusCode::OK)
            .upgrade("websocket")
            .finish();
        assert!(built.upgrade());

        let upgrade_value = built.headers().get(header::UPGRADE).unwrap();
        assert_eq!(upgrade_value, HeaderValue::from_static("websocket"));
    }

    #[test]
    fn test_force_close() {
        let built = Response::build(StatusCode::OK).force_close().finish();

        assert!(!built.keep_alive());
    }

    #[test]
    fn test_content_type() {
        let plain = Response::build(StatusCode::OK)
            .content_type("text/plain")
            .body(Bytes::new());
        assert_eq!(plain.headers().get(CONTENT_TYPE).unwrap(), "text/plain");

        let script = Response::build(StatusCode::OK)
            .content_type(mime::TEXT_JAVASCRIPT)
            .body(Bytes::new());
        assert_eq!(script.headers().get(CONTENT_TYPE).unwrap(), "text/javascript");
    }

    #[test]
    fn test_into_builder() {
        let mut source: Response<_> = "test".into();
        assert_eq!(source.status(), StatusCode::OK);

        source.headers_mut().insert(
            HeaderName::from_static("cookie"),
            HeaderValue::from_static("cookie1=val100"),
        );

        let mut builder = ResponseBuilder::from(source);
        let rebuilt = builder.status(StatusCode::BAD_REQUEST).finish();
        assert_eq!(rebuilt.status(), StatusCode::BAD_REQUEST);

        let first_cookie = rebuilt.headers().get_all("Cookie").next().unwrap();
        assert_eq!(first_cookie.to_str().unwrap(), "cookie1=val100");
    }

    #[test]
    fn response_builder_header_insert_kv() {
        let mut builder = Response::build(StatusCode::OK);
        builder.insert_header(("Content-Type", "application/octet-stream"));
        let built = builder.finish();

        let expected = HeaderValue::from_static("application/octet-stream");
        assert_eq!(built.headers().get("Content-Type"), Some(&expected));
    }

    #[test]
    fn response_builder_header_insert_typed() {
        let mut builder = Response::build(StatusCode::OK);
        builder.insert_header((header::CONTENT_TYPE, mime::APPLICATION_OCTET_STREAM));
        let built = builder.finish();

        let expected = HeaderValue::from_static("application/octet-stream");
        assert_eq!(built.headers().get("Content-Type"), Some(&expected));
    }

    #[test]
    fn response_builder_header_append_kv() {
        let mut builder = Response::build(StatusCode::OK);
        builder.append_header(("Content-Type", "application/octet-stream"));
        builder.append_header(("Content-Type", "application/json"));
        let built = builder.finish();

        let values = content_type_values(&built);
        assert_eq!(values.len(), 2);
        for wanted in ["application/octet-stream", "application/json"] {
            assert!(values.contains(&HeaderValue::from_static(wanted)));
        }
    }

    #[test]
    fn response_builder_header_append_typed() {
        let mut builder = Response::build(StatusCode::OK);
        builder.append_header((header::CONTENT_TYPE, mime::APPLICATION_OCTET_STREAM));
        builder.append_header((header::CONTENT_TYPE, mime::APPLICATION_JSON));
        let built = builder.finish();

        let values = content_type_values(&built);
        assert_eq!(values.len(), 2);
        for wanted in ["application/octet-stream", "application/json"] {
            assert!(values.contains(&HeaderValue::from_static(wanted)));
        }
    }
}