1use axum::{extract::{Path, State, FromRequestParts}, http::{Request, StatusCode}, middleware::Next, body::Body, response::{Response, IntoResponse}, Json};
2use util::state::AppState;
3use crate::auth::claims::AuthUser;
4use crate::response::ApiResponse;
5use sea_orm::DatabaseConnection;
6use std::collections::HashMap;
7use sea_orm::EntityTrait;
8use sea_orm::QueryFilter;
9use sea_orm::ColumnTrait;
10use db::models::{
11 module::Entity as ModuleEntity,
12 plagiarism_case::{Entity as PlagiarismEntity, Column as PlagiarismColumn},
13 assignment::{Entity as AssignmentEntity, Column as AssignmentColumn},
14 assignment_task::{Entity as TaskEntity, Column as TaskColumn},
15 assignment_submission::{Entity as SubmissionEntity, Column as SubmissionColumn},
16 assignment_file::{Entity as FileEntity, Column as FileColumn},
17 user::Entity as UserEntity,
18 user
19};
20
21#[derive(serde::Serialize, Default)]
24pub struct Empty;
25
26async fn extract_and_insert_authuser(
28 mut req: Request<Body>
29) -> Result<(Request<Body>, AuthUser), (StatusCode, Json<ApiResponse<Empty>>)> {
30 let (mut parts, body) = req.into_parts();
31 let user = AuthUser::from_request_parts(&mut parts, &())
32 .await
33 .map_err(|_| (
34 StatusCode::UNAUTHORIZED,
35 Json(ApiResponse::error("Authentication required"))
36 ))?;
37
38 req = Request::from_parts(parts, body);
39 req.extensions_mut().insert(user.clone());
40 Ok((req, user))
41}
42
43async fn user_has_any_role(
45 db: &DatabaseConnection,
46 user_id: i64,
47 module_id: i64,
48 roles: &[&str],
49) -> bool {
50 for role in roles {
51 if user::Model::is_in_role(db, user_id, module_id, role).await.unwrap_or(false) {
52 return true;
53 }
54 }
55 false
56}
57
58pub async fn require_authenticated(
60 req: Request<Body>,
61 next: Next,
62) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
63 let (req, _user) = extract_and_insert_authuser(req).await?;
64
65 Ok(next.run(req).await)
66}
67
68pub async fn require_admin(
70 req: Request<Body>,
71 next: Next,
72) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
73 let (req, user) = extract_and_insert_authuser(req).await?;
74
75 if !user.0.admin {
76 return Err((
77 StatusCode::FORBIDDEN,
78 Json(ApiResponse::error("Admin access required"))
79 ));
80 }
81
82 Ok(next.run(req).await)
83}
84
85async fn require_role_base(
87 State(app_state): State<AppState>,
88 Path(params): Path<HashMap<String, String>>,
89 req: Request<Body>,
90 next: Next,
91 required_roles: &[&str],
92 failure_msg: &str,
93) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
94 let db: &DatabaseConnection = app_state.db();
95
96 let (req, user) = extract_and_insert_authuser(req).await?;
97
98 let module_id = params.get("module_id")
99 .and_then(|s| s.parse::<i64>().ok())
100 .ok_or((
101 StatusCode::BAD_REQUEST,
102 Json(ApiResponse::error("Missing or invalid module_id"))
103 ))?;
104
105 if user.0.admin {
106 return Ok(next.run(req).await);
107 }
108
109 if user_has_any_role(&db, user.0.sub, module_id, required_roles).await {
110 Ok(next.run(req).await)
111 } else {
112 Err((StatusCode::FORBIDDEN, Json(ApiResponse::error(failure_msg))))
113 }
114}
115
116pub async fn require_lecturer(
118 State(app_state): State<AppState>,
119 Path(params): Path<HashMap<String, String>>,
120 req: Request<Body>,
121 next: Next,
122) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
123 require_role_base(
124 State(app_state),
125 Path(params),
126 req,
127 next,
128 &["Lecturer"],
129 "Lecturer access required for this module"
130 ).await
131}
132
133pub async fn require_assistant_lecturer(
135 State(app_state): State<AppState>,
136 Path(params): Path<HashMap<String, String>>,
137 req: Request<Body>,
138 next: Next,
139) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
140 require_role_base(
141 State(app_state),
142 Path(params),
143 req,
144 next,
145 &["AssistantLecturer"],
146 "Assistant lecturer access required for this module"
147 ).await
148}
149
150pub async fn require_tutor(
152 State(app_state): State<AppState>,
153 Path(params): Path<HashMap<String, String>>,
154 req: Request<Body>,
155 next: Next,
156) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
157 require_role_base(
158 State(app_state),
159 Path(params),
160 req,
161 next,
162 &["Tutor"],
163 "Tutor access required for this module"
164 ).await
165}
166
167pub async fn require_student(
169 State(app_state): State<AppState>,
170 Path(params): Path<HashMap<String, String>>,
171 req: Request<Body>,
172 next: Next,
173) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
174 require_role_base(
175 State(app_state),
176 Path(params),
177 req,
178 next,
179 &["Student"],
180 "Student access required for this module"
181 ).await
182}
183
184pub async fn require_lecturer_or_assistant_lecturer(
186 State(app_state): State<AppState>,
187 Path(params): Path<HashMap<String, String>>,
188 req: Request<Body>,
189 next: Next,
190) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
191 require_role_base(
192 State(app_state),
193 Path(params),
194 req,
195 next,
196 &["Lecturer", "AssistantLecturer"],
197 "Lecturer or assistant lecturer access required for this module"
198 ).await
199}
200
201pub async fn require_lecturer_or_tutor(
204 State(app_state): State<AppState>,
205 Path(params): Path<HashMap<String, String>>,
206 req: Request<Body>,
207 next: Next,
208) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
209 require_role_base(
210 State(app_state),
211 Path(params),
212 req,
213 next,
214 &["Lecturer", "Tutor"],
215 "Lecturer or tutor access required for this module"
216 ).await
217}
218
219pub async fn require_assigned_to_module(
221 State(app_state): State<AppState>,
222 Path(params): Path<HashMap<String, String>>,
223 req: Request<Body>,
224 next: Next,
225) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
226 require_role_base(
227 State(app_state),
228 Path(params),
229 req,
230 next,
231 &["Lecturer", "AssistantLecturer", "Tutor", "Student"],
232 "User not assigned to this module"
233 ).await
234}
235
236pub async fn require_ready_assignment(
237 State(app_state): State<AppState>,
238 Path(params): Path<HashMap<String, String>>,
239 req: Request<Body>,
240 next: Next,
241) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
242 let db = app_state.db();
243
244 let module_id = params.get("module_id")
245 .and_then(|s| s.parse::<i64>().ok())
246 .ok_or((
247 StatusCode::BAD_REQUEST,
248 Json(ApiResponse::error("Missing or invalid module_id"))
249 ))?;
250
251 let assignment_id = params.get("assignment_id")
252 .and_then(|s| s.parse::<i64>().ok())
253 .ok_or((
254 StatusCode::BAD_REQUEST,
255 Json(ApiResponse::error("Missing or invalid assignment_id"))
256 ))?;
257
258 if let Err(e) = db::models::assignment::Model::try_transition_to_ready(db, module_id, assignment_id).await {
259 return Err((
260 StatusCode::INTERNAL_SERVER_ERROR,
261 Json(ApiResponse::error(format!("Failed to transition assignment to ready: {}", e)))
262 ));
263 }
264
265 let assignment = match AssignmentEntity::find_by_id(assignment_id).one(db).await {
266 Ok(Some(a)) => a,
267 Ok(None) => {
268 return Err((
269 StatusCode::NOT_FOUND,
270 Json(ApiResponse::error(format!(
271 "Assignment {} in Module {} not found.",
272 assignment_id, module_id
273 ))),
274 ));
275 }
276 Err(_) => {
277 return Err((
278 StatusCode::INTERNAL_SERVER_ERROR,
279 Json(ApiResponse::error("Database error while checking assignment")),
280 ));
281 }
282 };
283
284 if assignment.status == db::models::assignment::Status::Setup {
285 return Err((
286 StatusCode::FORBIDDEN,
287 Json(ApiResponse::error("Assignment is still in Setup stage"))
288 ));
289 }
290
291 Ok(next.run(req).await)
292}
293
294async fn check_module_exists(
297 module_id: i32,
298 db: &DatabaseConnection,
299) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
300 let found = ModuleEntity::find_by_id(module_id)
301 .one(db)
302 .await
303 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking module"))))?;
304
305 if found.is_none() {
306 return Err((
307 StatusCode::NOT_FOUND,
308 Json(ApiResponse::error(format!("Module {} not found.", module_id))),
309 ));
310 }
311 Ok(())
312}
313
314async fn check_user_exists(
315 user_id: i32,
316 db: &DatabaseConnection,
317) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
318 let found = UserEntity::find_by_id(user_id)
319 .one(db)
320 .await
321 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking user"))))?;
322
323 if found.is_none() {
324 return Err((
325 StatusCode::NOT_FOUND,
326 Json(ApiResponse::error(format!("User {} not found.", user_id))),
327 ));
328 }
329 Ok(())
330}
331
332async fn check_assignment_hierarchy(
333 module_id: i32,
334 assignment_id: i32,
335 db: &DatabaseConnection,
336) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
337 check_module_exists(module_id, db).await?;
338
339 let found = AssignmentEntity::find()
340 .filter(AssignmentColumn::Id.eq(assignment_id))
341 .filter(AssignmentColumn::ModuleId.eq(module_id))
342 .one(db)
343 .await
344 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking assignment"))))?;
345
346 if found.is_none() {
347 return Err((
348 StatusCode::NOT_FOUND,
349 Json(ApiResponse::error(format!("Assignment {} in Module {} not found.", assignment_id, module_id))),
350 ));
351 }
352 Ok(())
353}
354
355async fn check_task_hierarchy(
356 module_id: i32,
357 assignment_id: i32,
358 task_id: i32,
359 db: &DatabaseConnection,
360) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
361 check_assignment_hierarchy(module_id, assignment_id, db).await?;
362
363 let found = TaskEntity::find()
364 .filter(TaskColumn::Id.eq(task_id))
365 .filter(TaskColumn::AssignmentId.eq(assignment_id))
366 .one(db)
367 .await
368 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking task"))))?;
369
370 if found.is_none() {
371 return Err((
372 StatusCode::NOT_FOUND,
373 Json(ApiResponse::error(format!("Task {} in Assignment {} not found.", task_id, assignment_id))),
374 ));
375 }
376 Ok(())
377}
378
379async fn check_submission_hierarchy(
380 module_id: i32,
381 assignment_id: i32,
382 submission_id: i32,
383 db: &DatabaseConnection,
384) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
385 check_assignment_hierarchy(module_id, assignment_id, db).await?;
386
387 let found = SubmissionEntity::find()
388 .filter(SubmissionColumn::Id.eq(submission_id))
389 .filter(SubmissionColumn::AssignmentId.eq(assignment_id))
390 .one(db)
391 .await
392 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking submission"))))?;
393
394 if found.is_none() {
395 return Err((
396 StatusCode::NOT_FOUND,
397 Json(ApiResponse::error(format!("Submission {} in Assignment {} not found.", submission_id, assignment_id))),
398 ));
399 }
400 Ok(())
401}
402
403async fn check_file_hierarchy(
404 module_id: i32,
405 assignment_id: i32,
406 file_id: i32,
407 db: &DatabaseConnection,
408) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
409 check_assignment_hierarchy(module_id, assignment_id, db).await?;
410
411 let found = FileEntity::find()
412 .filter(FileColumn::Id.eq(file_id))
413 .filter(FileColumn::AssignmentId.eq(assignment_id))
414 .one(db)
415 .await
416 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking file"))))?;
417
418 if found.is_none() {
419 return Err((
420 StatusCode::NOT_FOUND,
421 Json(ApiResponse::error(format!("File {} in Assignment {} not found.", file_id, assignment_id))),
422 ));
423 }
424 Ok(())
425}
426
427pub async fn check_ticket_hierarchy(
428 module_id: i32,
429 assignment_id: i32,
430 ticket_id: i32,
431 db: &DatabaseConnection,
432) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
433 check_assignment_hierarchy(module_id, assignment_id, db).await?;
434
435 let found = db::models::tickets::Entity::find()
436 .filter(db::models::tickets::Column::Id.eq(ticket_id))
437 .filter(db::models::tickets::Column::AssignmentId.eq(assignment_id))
438 .one(db)
439 .await
440 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking ticket"))))?;
441
442 if found.is_none() {
443 return Err((
444 StatusCode::NOT_FOUND,
445 Json(ApiResponse::error(format!("Ticket {} in Assignment {} not found.", ticket_id, assignment_id))),
446 ));
447 }
448 Ok(())
449}
450
451pub async fn check_message_hierarchy(
452 module_id: i32,
453 assignment_id: i32,
454 ticket_id: i32,
455 message_id: i32,
456 db: &DatabaseConnection,
457) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
458 check_ticket_hierarchy(module_id, assignment_id, ticket_id, db).await?;
459
460 let found = db::models::ticket_messages::Entity::find()
461 .filter(db::models::ticket_messages::Column::Id.eq(message_id))
462 .filter(db::models::ticket_messages::Column::TicketId.eq(ticket_id))
463 .one(db)
464 .await
465 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking message"))))?;
466
467 if found.is_none() {
468 return Err((
469 StatusCode::NOT_FOUND,
470 Json(ApiResponse::error(format!("Message {} in Ticket {} not found.", message_id, ticket_id))),
471 ));
472 }
473 Ok(())
474}
475
476pub async fn check_plagiarism_hierarchy(
477 module_id: i32,
478 assignment_id: i32,
479 case_id: i32,
480 db: &DatabaseConnection,
481) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
482 check_assignment_hierarchy(module_id, assignment_id, db).await?;
483
484 let found = PlagiarismEntity::find()
485 .filter(PlagiarismColumn::Id.eq(case_id))
486 .filter(PlagiarismColumn::AssignmentId.eq(assignment_id))
487 .one(db)
488 .await
489 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking plagiarism case"))))?;
490
491 if found.is_none() {
492 return Err((
493 StatusCode::NOT_FOUND,
494 Json(ApiResponse::error(format!("Plagiarism case {} in Assignment {} not found.", case_id, assignment_id))),
495 ));
496 }
497 Ok(())
498}
499
500pub async fn check_announcement_hierarchy(
501 module_id: i32,
502 announcement_id: i32,
503 db: &DatabaseConnection,
504) -> Result<(), (StatusCode, Json<ApiResponse<Empty>>)> {
505 check_module_exists(module_id, db).await?;
506
507 let found = db::models::announcements::Entity::find()
508 .filter(db::models::announcements::Column::Id.eq(announcement_id))
509 .filter(db::models::announcements::Column::ModuleId.eq(module_id))
510 .one(db)
511 .await
512 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking announcement"))))?;
513
514 if found.is_none() {
515 return Err((
516 StatusCode::NOT_FOUND,
517 Json(ApiResponse::error(format!("Announcement {} in Module {} not found.", announcement_id, module_id))),
518 ));
519 }
520 Ok(())
521}
522
523pub async fn validate_known_ids(
524 State(app_state): State<AppState>,
525 Path(params): Path<HashMap<String, String>>,
526 req: Request<Body>,
527 next: Next,
528) -> Result<Response, Response> {
529 let db = app_state.db();
530
531 let mut module_id: Option<i32> = None;
532 let mut assignment_id: Option<i32> = None;
533 let mut task_id: Option<i32> = None;
534 let mut submission_id: Option<i32> = None;
535 let mut file_id: Option<i32> = None;
536 let mut user_id: Option<i32> = None;
537 let mut ticket_id: Option<i32> = None;
538 let mut message_id: Option<i32> = None;
539 let mut case_id: Option<i32> = None;
540 let mut announcement_id: Option<i32> = None;
541
542 for (key, raw) in ¶ms {
543 let id = raw.parse::<i32>().map_err(|_| {
544 (StatusCode::BAD_REQUEST, Json(ApiResponse::<Empty>::error(format!("Invalid {}: '{}'. Must be an integer.", key, raw)))).into_response()
545 })?;
546 match key.as_str() {
547 "module_id" => module_id = Some(id),
548 "assignment_id" => assignment_id = Some(id),
549 "task_id" => task_id = Some(id),
550 "submission_id" => submission_id = Some(id),
551 "file_id" => file_id = Some(id),
552 "user_id" => user_id = Some(id),
553 "ticket_id" => ticket_id = Some(id),
554 "case_id" => case_id = Some(id),
555 "announcement_id" => announcement_id = Some(id),
556 "message_id" => message_id = Some(id),
557 _ => return Err((StatusCode::BAD_REQUEST, Json(ApiResponse::<Empty>::error(format!("Unexpected parameter: '{}'.", key)))).into_response()),
558 }
559 }
560
561 if let Some(uid) = user_id {
562 check_user_exists(uid, db).await.map_err(|e| e.into_response())?;
563 }
564 if let Some(mid) = module_id {
565 check_module_exists(mid, db).await.map_err(|e| e.into_response())?;
566 }
567 if let (Some(mid), Some(aid)) = (module_id, assignment_id) {
568 check_assignment_hierarchy(mid, aid, db).await.map_err(|e| e.into_response())?;
569 }
570 if let (Some(mid), Some(aid), Some(tid)) = (module_id, assignment_id, task_id) {
571 check_task_hierarchy(mid, aid, tid, db).await.map_err(|e| e.into_response())?;
572 }
573 if let (Some(mid), Some(aid), Some(sid)) = (module_id, assignment_id, submission_id) {
574 check_submission_hierarchy(mid, aid, sid, db).await.map_err(|e| e.into_response())?;
575 }
576 if let (Some(mid), Some(aid), Some(fid)) = (module_id, assignment_id, file_id) {
577 check_file_hierarchy(mid, aid, fid, db).await.map_err(|e| e.into_response())?;
578 }
579 if let (Some(mid), Some(aid), Some(tid)) = (module_id, assignment_id, ticket_id) {
580 check_ticket_hierarchy(mid, aid, tid, db).await.map_err(|e| e.into_response())?;
581 }
582 if let (Some(mid), Some(aid), Some(sid)) = (module_id, assignment_id, case_id) {
583 check_plagiarism_hierarchy(mid, aid, sid, db).await.map_err(|e| e.into_response())?;
584 }
585
586 if let (Some(mid), Some(ann_id)) = (module_id, announcement_id) {
587 check_announcement_hierarchy(mid, ann_id, db).await.map_err(|e| e.into_response())?;
588 }
589
590 if let (Some(mid), Some(aid), Some(tid), Some(meid)) = (module_id, assignment_id, ticket_id, message_id) {
591 check_message_hierarchy(mid, aid, tid, meid, db).await.map_err(|e| e.into_response())?;
592 }
593
594 Ok(next.run(req).await)
595}
596
597pub async fn require_ticket_ws_access(
599 State(app_state): State<AppState>,
600 Path(params): Path<HashMap<String, String>>,
601 req: axum::http::Request<Body>,
602 next: Next,
603) -> Result<Response, (StatusCode, Json<ApiResponse<Empty>>)> {
604 let db = app_state.db();
605
606 let (req, user) = extract_and_insert_authuser(req).await?;
608
609 let ticket_id = params.get("ticket_id")
611 .and_then(|s| s.parse::<i64>().ok())
612 .ok_or((
613 StatusCode::BAD_REQUEST,
614 Json(ApiResponse::error("Missing or invalid ticket_id")),
615 ))?;
616
617 let ticket = db::models::tickets::Entity::find_by_id(ticket_id)
619 .one(db)
620 .await
621 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking ticket"))))?
622 .ok_or((StatusCode::NOT_FOUND, Json(ApiResponse::error("Ticket not found"))))?;
623
624 if ticket.user_id == user.0.sub {
626 return Ok(next.run(req).await);
627 }
628
629 if user.0.admin {
631 return Ok(next.run(req).await);
632 }
633
634 let assignment = db::models::assignment::Entity::find_by_id(ticket.assignment_id)
636 .one(db)
637 .await
638 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(ApiResponse::error("Database error while checking assignment"))))?
639 .ok_or((StatusCode::NOT_FOUND, Json(ApiResponse::error("Assignment not found for ticket"))))?;
640
641 let module_id = assignment.module_id;
642
643 if user_has_any_role(db, user.0.sub, module_id, &["Lecturer", "AssistantLecturer", "Tutor"]).await {
645 return Ok(next.run(req).await);
646 }
647
648 Err((StatusCode::FORBIDDEN, Json(ApiResponse::error("Not allowed to access this ticket websocket"))))
650}